Building the autograd graph comes with additional costs, especially when the model has millions of parameters. To address this, PyTorch lets us switch off autograd when we don’t need it, using the
torch.no_grad context manager. On our small problem, though, we won’t see any meaningful advantage in terms of speed or memory consumption.
First, we wrap the operations in a
no_grad context using a Python
with statement. Within the
with block, the PyTorch autograd mechanism looks away: it does not add edges to the forward graph, so results computed inside the block have no grad_fn and do not require gradients. Outside such a block, the forward graph that PyTorch records is kept until we call
backward, which consumes it and leaves us with only the leaf nodes, such as the params.
import torch

x = torch.ones(2, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False
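To make the "no edges are added" point concrete, the sketch below runs the same multiplication inside and outside a no_grad block and inspects grad_fn, the attribute that links a tensor to the operation that produced it in the graph. The variable names are only for illustration.

```python
import torch

x = torch.ones(2, requires_grad=True)

# Outside no_grad: the multiplication is recorded, so the result gets a grad_fn.
y_tracked = x * 2
print(y_tracked.requires_grad, y_tracked.grad_fn is not None)  # True True

# Inside no_grad: nothing is recorded, so the result has no grad_fn.
with torch.no_grad():
    y_untracked = x * 2
print(y_untracked.requires_grad, y_untracked.grad_fn)  # False None

# The switch only affects operations run inside the block; x itself is unchanged.
print(x.requires_grad)  # True
```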
Because we are not training anything, we tell PyTorch that we won’t need gradients by running the code inside a
with torch.no_grad() block. The
torch.no_grad() context manager is strictly a runtime switch: it changes how operations are recorded while the block executes, not the tensors or the model involved.
On larger models it can also be noticeably faster, because the
with torch.no_grad() context manager explicitly informs PyTorch that no gradients need to be computed, so no graph has to be built. Context managers such as
torch.no_grad() can be used to control autograd behavior.
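Besides the plain context manager, PyTorch provides a few related switches. The sketch below shows torch.no_grad() used as a decorator, torch.set_grad_enabled() driven by a boolean flag, and torch.enable_grad() turning tracking back on inside a no_grad block. The nn.Linear model, the random inputs, and the helper evaluate are hypothetical stand-ins, not part of the original example.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in model and data; any nn.Module would do.
model = nn.Linear(3, 1)
inputs = torch.randn(4, 3)

# 1) no_grad as a context manager around evaluation code.
with torch.no_grad():
    preds_ctx = model(inputs)

# 2) no_grad as a function decorator (evaluate is a hypothetical helper).
@torch.no_grad()
def evaluate(m, data):
    return m(data)

preds_dec = evaluate(model, inputs)

# 3) set_grad_enabled takes a boolean, handy when one code path serves
#    both training (True) and evaluation (False).
is_train = False
with torch.set_grad_enabled(is_train):
    preds_flag = model(inputs)

# 4) enable_grad re-enables tracking inside a no_grad block.
with torch.no_grad():
    with torch.enable_grad():
        preds_inner = model(inputs)

print(preds_ctx.requires_grad, preds_dec.requires_grad,
      preds_flag.requires_grad, preds_inner.requires_grad)
# False False False True
```

set_grad_enabled is convenient when the same loop serves training and validation, since the mode becomes an argument rather than a separate code path.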
tensor.detach() creates a new tensor that shares storage with the original tensor but does not require grad. Use
detach() when you want to remove a tensor from a computation graph:
tensor.detach() returns a tensor detached from the graph that created the original, which makes the returned tensor a leaf.
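A quick way to check the "shares storage" claim is to modify the detached tensor in place and look at the original. The sketch below does this; note that such in-place edits are visible through both tensors and, if the original had been saved for a backward pass, could make backward() raise an error.

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
b = a * 2          # non-leaf: produced by an operation, so it has a grad_fn
c = b.detach()     # new tensor, same storage as b, outside the graph

print(c.requires_grad, c.grad_fn, c.is_leaf)  # False None True

# Shared storage: an in-place change to c is visible through b.
c[0] = 100.0
print(b[0].item())                   # 100.0
print(c.data_ptr() == b.data_ptr())  # True
```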
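import torch
from torchviz import make_dot

# Full graph: both y and z depend on x
x = torch.tensor(2.0, requires_grad=True)
y = 2 * x
z = 3 + x
r = (y + z).sum()
make_dot(r)  # returns a graphviz Digraph; renders inline in a notebook

# Detach: z is computed from a detached copy of x
x = torch.tensor(2.0, requires_grad=True)
y = 2 * x
z = 3 + x.detach()
r = (y + z).sum()
make_dot(r)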
In the second graph, the z = 3 + x branch is detached from the computational graph by
detach(). detach() operates on a tensor and returns a new tensor that shares the same data but is detached from the computation graph at this point, so the backward pass stops there.
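The effect on gradients can be checked without torchviz. In this sketch, which reuses the two graphs above, r = 2x + (3 + x) gives dr/dx = 3 when nothing is detached; once z is built from x.detach(), only the y branch carries a gradient back to x, so dr/dx = 2.

```python
import torch

# Without detach: both y and z depend on x, so dr/dx = 2 + 1 = 3.
x = torch.tensor(2.0, requires_grad=True)
y = 2 * x
z = 3 + x
r = (y + z).sum()
r.backward()
grad_full = x.grad.item()

# With detach: z is built from a detached copy of x, so only y contributes
# and dr/dx = 2.
x = torch.tensor(2.0, requires_grad=True)
y = 2 * x
z = 3 + x.detach()
r = (y + z).sum()
r.backward()
grad_detached = x.grad.item()

print(grad_full, grad_detached)  # 3.0 2.0
```

The forward value of r is the same in both cases; only the path the backward pass can follow changes.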
Whereas torch.no_grad() switches off gradient tracking for a whole block, tensor.detach() is used when this needs to be specified for only a limited number of tensors. A typical example is displaying the loss and accuracy after an epoch ends in neural network training: at that moment, keeping those values attached to the graph only consumes resources, since their gradients don’t matter when displaying results.
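As an illustration of the logging use case, the sketch below accumulates a running loss from detached copies so the running total never holds references to each step’s graph, then uses .item() to get a plain Python float for display. The model, optimizer, data, and step count are hypothetical choices made only to have something to train.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: a linear model fitted to random data.
model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()
inputs, targets = torch.randn(8, 3), torch.randn(8, 1)

num_steps = 5
running_loss = torch.tensor(0.0)
for step in range(num_steps):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # Accumulate a detached copy: the running total records no graph history
    # and keeps no reference to this step's graph.
    running_loss += loss.detach()

# For display, .item() converts the 0-dim tensor to a Python float.
print(f"mean loss: {running_loss.item() / num_steps:.4f}")
print(running_loss.requires_grad)  # False
```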
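We should not assume that using
torch.no_grad necessarily means the outputs do not require gradients. There are particular circumstances in which
requires_grad is not False even for a tensor created in a
no_grad context: factory functions such as torch.ones() called with requires_grad=True are not affected by no_grad, and a nested torch.enable_grad() block turns tracking back on. It is best to use the
detach function if we need to be sure.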
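The factory-function exception is easy to reproduce. In this sketch, a tensor created with requires_grad=True inside no_grad still requires grad, while the result of an ordinary operation inside the same block does not; detach() removes any doubt.

```python
import torch

with torch.no_grad():
    # Factory function with an explicit requires_grad: not affected by no_grad.
    w = torch.ones(2, requires_grad=True)
    # Ordinary computation inside no_grad: not recorded.
    doubled = w * 2

print(w.requires_grad)        # True
print(doubled.requires_grad)  # False

# detach() guarantees a tensor that does not require grad.
w_safe = w.detach()
print(w_safe.requires_grad)   # False
```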