Building the autograd graph comes with additional costs, especially when the model has millions of parameters. To address this, PyTorch lets us switch off autograd when we don't need it, using the torch.no_grad context manager. On a small problem like ours we won't see any meaningful advantage in speed or memory consumption, but it matters at scale.

First, we wrap the computation in a no_grad context using a Python with statement. Within the with block, the PyTorch autograd mechanism looks away: it does not add edges to the forward graph. (Normally, the forward graph that PyTorch records is consumed when we call backward, leaving only the leaf nodes, such as the parameters.)

import torch

x = torch.ones(2, requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False

As we are not training anything here, we tell PyTorch that we do not want gradients by running the code inside a with torch.no_grad() block. The torch.no_grad() context manager is strictly a runtime switch: it does not change the tensors themselves, only whether the operations executed inside the block are recorded.

In addition, running under with torch.no_grad() is quite a bit faster, because the context manager explicitly tells PyTorch that no gradients need to be computed. Context managers like with torch.no_grad(): are the standard way to control autograd behavior.
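To make the runtime-switch behavior concrete, here is a minimal sketch of a typical inference pass (the nn.Linear model and the random inputs are just illustrative placeholders): inside the block no graph is recorded, and as soon as the block ends autograd is active again.

import torch
import torch.nn as nn

# Hypothetical tiny model, just for illustration.
model = nn.Linear(10, 2)
inputs = torch.randn(4, 10)

model.eval()
with torch.no_grad():          # runtime switch: nothing below is recorded
    outputs = model(inputs)    # forward pass only, no graph is built

print(outputs.requires_grad)   # False

# Outside the block, autograd is active again.
outputs2 = model(inputs)
print(outputs2.requires_grad)  # True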

tensor.detach()

tensor.detach() creates a new tensor that shares storage with the original tensor but does not require grad. Use detach() when you want to remove a tensor from the computation graph: it detaches the tensor from the graph that created it, making it a leaf.
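As a small sketch of what "shares storage" and "leaf" mean in practice (the values are arbitrary):

import torch

a = torch.ones(3, requires_grad=True)
b = a.detach()            # new tensor, same underlying storage, no grad

b[0] = 5.0                # in-place change on b is visible through a
print(a)                  # tensor([5., 1., 1.], requires_grad=True)
print(b.requires_grad)    # False
print(b.is_leaf)          # True

The torchviz snippets below visualize the same effect on the graph itself.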

import torch
from torchviz import make_dot

x = torch.tensor(2.0, requires_grad=True)
y = 2 * x
z = 3 + x
r = (y + z).sum()
make_dot(r)

# Detach one branch of the graph
x = torch.tensor(2.0, requires_grad=True)
y = 2 * x
z = 3 + x.detach()
r = (y + z).sum()
make_dot(r)

In the second graph, the branch z = 3 + x.detach() is cut off from the computational graph: detach() makes z a constant with respect to x, so no gradient flows through that branch.
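A short sketch of how this changes the gradients, following directly from the two snippets above: with the full graph r = 2x + (3 + x), so dr/dx = 3; with the detached branch, z is treated as a constant and dr/dx = 2.

import torch

# Full graph: dr/dx = d(2x + 3 + x)/dx = 3
x = torch.tensor(2.0, requires_grad=True)
r = (2 * x + (3 + x)).sum()
r.backward()
print(x.grad)  # tensor(3.)

# Detached branch: z = 3 + x.detach() is a constant, so dr/dx = 2
x = torch.tensor(2.0, requires_grad=True)
r = (2 * x + (3 + x.detach())).sum()
r.backward()
print(x.grad)  # tensor(2.)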

PyTorch Detach

detach() operates on a tensor and returns a new tensor that shares the same storage but is detached from the computation graph, so the backward pass stops at that point.

tensor.detach() is the right tool when gradients have to be disabled for only a few specific tensors, for example when logging the loss and accuracy at the end of an epoch during neural network training. At that moment their gradients no longer matter, and keeping those tensors attached to the graph would only consume resources.
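As a minimal sketch of this pattern (the model, criterion, and data below are hypothetical placeholders), the loss used for the backward pass stays attached to the graph, while the value accumulated for reporting is detached:

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(8, 10)
targets = torch.randint(0, 2, (8,))

running_loss = 0.0
for epoch in range(3):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # Detach (or use .item()) before logging, so the graph
    # is not kept alive just for the printed number.
    running_loss += loss.detach().item()

print(f"mean loss over 3 epochs: {running_loss / 3:.4f}")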

We should not assume that using torch.no_grad necessarily implies that the outputs do not require gradients. There are particular circumstances in which requires_grad is not set to False even for tensors created in a no_grad context. It is best to use the detach function if we need to be sure.
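One such case, as a quick sketch: torch.no_grad() suppresses graph recording for operations, but a tensor explicitly created with requires_grad=True inside the block still reports requires_grad as True. Calling detach() gives a tensor that is guaranteed not to require gradients.

import torch

with torch.no_grad():
    a = torch.ones(1, requires_grad=True)   # factory call: the flag is honored
    b = a * 2                               # operation: not recorded

print(a.requires_grad)            # True, despite the no_grad block
print(b.requires_grad)            # False
print(a.detach().requires_grad)   # False, detach makes it certain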
