tensor.detach() creates a new tensor that shares storage with the original tensor but does not require grad. Sharing storage means that if you detach a tensor from its original and then modify the detached copy in-place, both tensors change.

import torch

x = torch.tensor([1.0], requires_grad=True)
y = x.detach()

print(y.requires_grad)  # False

y[0] = 2.0
print(x)  # tensor([2.], requires_grad=True)

So both change because they share storage: detach() gives a new tensor that is a view of the original one, so any in-place modification of one will affect the other. Use tensor.detach().clone() if you want a new tensor backed by new memory that does not share the autograd history of the original one.
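A minimal sketch of the difference, continuing from the snippet above: the clone gets its own memory, so modifying it in-place leaves x untouched.

z = x.detach().clone()  # no grad history, and a fresh copy of the data

z[0] = 5.0
print(x)  # tensor([2.], requires_grad=True)  -> unchanged
print(z)  # tensor([5.])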

In-place modification

The detached tensor is a view of the “original one”: the new tensor instance internally holds a pointer/reference to the data of the other tensor, which is where the actual memory with the data lives.

x.untyped_storage().tolist() == y.untyped_storage().tolist()  # True
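A more direct way to see that they point at the same memory (using the x and y from above) is to compare data pointers:

print(x.data_ptr() == y.data_ptr())  # True: same underlying memory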

tensor.detach() detaches the tensor from its history, and the result does not require gradients. So there is nothing wrong with modifying the detached tensor in-place.

The point of tensor.detach() and torch.no_grad() is to be able to do in-place modifications without autograd tracking them. If you want your ops to be differentiable, you shouldn’t do the in-place modification, because it modifies the original tensor in the graph.
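As a sketch of where this is useful, here is the classic manual parameter update (the tensor w and the learning rate 0.1 are only illustrative): the in-place subtraction on a leaf tensor is fine because torch.no_grad() keeps autograd from recording it.

w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()
loss.backward()

with torch.no_grad():
    w -= 0.1 * w.grad  # in-place update of a leaf tensor, not recorded by autograd

print(w.requires_grad)  # True: w is still a leaf that requires grad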

In-place operation

If you change the values of a tensor in-place and then use it, PyTorch doesn’t consider that to be an error by itself. And if you change some values while explicitly hiding the change from autograd with .detach(), PyTorch assumes you have a good reason to do so.

The forward pass needs to save some tensor values to be able to compute the backward pass. If you modify one of these saved tensors before running the backward, you will get an error, because the original value that was needed to compute the right gradients does not exist anymore (it was modified in-place). For example, the output of exp() is required in the backward, so if we modify it in-place, we get an error:

a = torch.tensor([2.0, 3.0], requires_grad=True)

b = a.exp()

print("b", b)  # tensor([ 7.3891, 20.0855], grad_fn=<ExpBackward0>)

print("b - inplace", b.fill_(2))  # tensor([2., 2.], grad_fn=<FillBackward2>)

make_dot(b)  # visualize the autograd graph (make_dot from torchviz)
(Figure: the autograd graph of b after the in-place fill_.)
b.sum().backward()

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2]], which is output 0 of ExpBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

To compute the backward of exp(a), we need to compute grad_out * exp(a). So PyTorch re-uses the result of the forward pass instead of recomputing it, for performance reasons.
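This is roughly what saving the result for the backward looks like if you write an exp-like op as a custom autograd.Function yourself (MyExp is just an illustrative name; this is a sketch of the standard pattern, not PyTorch’s internal implementation):

class MyExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a):
        result = a.exp()
        ctx.save_for_backward(result)  # keep the forward output around for the backward
        return result

    @staticmethod
    def backward(ctx, grad_out):
        result, = ctx.saved_tensors    # errors if result was modified in-place since the forward
        return grad_out * result       # d exp(a) / da = exp(a)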

If you try to run the backward, you will get an error. In-place modifications are just harder to detect during the forward: autograd does not catch the in-place operation on a tensor in the forward graph at the moment it happens, it only detects it during the backward, by checking the tensor’s version counter.
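Two ways to see this mechanism at work (._version is an internal counter, shown here only for illustration; the anomaly-detection hint comes straight from the error message above):

a = torch.tensor([2.0, 3.0], requires_grad=True)
b = a.exp()
print(b._version)  # 0
b.fill_(2)
print(b._version)  # 1: no longer matches the version recorded for the backward

torch.autograd.set_detect_anomaly(True)
c = a.exp()
c.fill_(2)
c.sum().backward()  # same RuntimeError, plus a warning with the traceback of the forward exp()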

Clone tensor

.clone() is useful to create a copy of the original variable that doesn’t forget the history of ops, so it allows gradient flow and avoids errors with in-place ops. The main problems with in-place ops are overwriting data needed for the backward pass and applying an in-place op to a leaf tensor that requires grad; operating on a clone instead, as below, triggers neither, so there is no error message.

a = torch.tensor([2.0, 2.0, 0.0], requires_grad=True)
a_clone = a.clone()

print(f'a = {a}')   #tensor([2., 2., 0.], requires_grad=True)
print(f'a_clone = {a_clone}') #tensor([2., 2., 0.], grad_fn=<CloneBackward0>)

a_clone.mul_(2)
print(f'a = {a}') #tensor([2., 2., 0.], requires_grad=True)
print(f'a_clone = {a_clone}') #tensor([4., 4., 0.], grad_fn=<MulBackward0>)
a_clone.sum().backward()

This works because the output of the clone is not required during the backward pass, so there is no reason to throw an error if it was changed in-place. For clone, the backward just needs to pass grad_out through (a no-op), so there is no need to save the value of any tensor from the forward. All the ops (except those run under .detach() or torch.no_grad()) are recorded; otherwise, we wouldn’t be able to compute the correct gradients.
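A quick check that the gradient still flows back to a through the clone, continuing the snippet above: after the in-place mul_, a_clone holds 2 * a, so the gradient of a_clone.sum() with respect to a is 2 everywhere.

print(a.grad)  # tensor([2., 2., 2.])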
