Manipulating tensor dimensions is often necessary to make tensors compatible with the input that a model or an operation expects. In this tutorial, you will learn how to add a dimension to a tensor with torch.unsqueeze().

Certain operations require input tensors with a specific number of dimensions (that is, rank) and a specific number of elements along each dimension (shape). Thus, we might need to change the shape of a tensor, add a new dimension, or squeeze out an unnecessary one.

PyTorch provides useful functions (or operations) to achieve this, such as torch.unsqueeze() and torch.squeeze(). Let’s take a look at an example:

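import torch

# Values and dtype are taken from the printed output below
x = torch.tensor([[7, 9], [4, 3], [1, 6], [8, 5], [2, 4], [5, 8]], dtype=torch.int32)
print(x.shape) #torch.Size([6, 2])

# Insert a new dimension of size one at position 0
x_unsqueeze = torch.unsqueeze(x, 0)
print(x_unsqueeze.shape) #torch.Size([1, 6, 2])
print(x_unsqueeze)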

torch.Size([6, 2])
torch.Size([1, 6, 2])
tensor([[[7, 9],
         [4, 3],
         [1, 6],
         [8, 5],
         [2, 4],
         [5, 8]]], dtype=torch.int32)

The call to unsqueeze adds a singleton dimension, turning the 2D tensor of shape [6, 2] into a 3D tensor of shape [1, 6, 2] without changing its contents: no extra elements are added.
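As a small illustration of why this matters, the sketch below uses torch.bmm, which expects 3D (batch, n, m) inputs, on two 2D matrices. The tensors a and b are made-up placeholders and are not part of the original example. A singleton batch dimension is added with unsqueeze and removed again with squeeze.

```python
import torch

# Hypothetical 2D inputs, chosen only for illustration
a = torch.ones(2, 3)
b = torch.ones(3, 4)

# torch.bmm expects 3D inputs, so add a leading batch dimension of size one
batched = torch.bmm(a.unsqueeze(0), b.unsqueeze(0))
print(batched.shape)  # torch.Size([1, 2, 4])

# squeeze(0) removes the singleton batch dimension again
result = batched.squeeze(0)
print(result.shape)   # torch.Size([2, 4])
```

Here squeeze(0) is the inverse of unsqueeze(0): it drops the size-one dimension and leaves the remaining elements untouched.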

Syntax: torch.unsqueeze(input, dim) → Tensor

torch.unsqueeze() returns a new tensor with a dimension of size one inserted at the specified position. A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used; for a 2D tensor, that is any integer from -3 to 2. A negative dim corresponds to unsqueeze() applied at dim = dim + input.dim() + 1.
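To make the dim rule concrete, here is a minimal sketch that tries every valid dim on a tensor with the same shape as x. The zero-filled tensor is only a stand-in, since the values do not affect the resulting shapes.

```python
import torch

# Placeholder tensor with the same shape as x in the example
x = torch.zeros(6, 2)

# For a 2D tensor, valid dim values lie in [-3, 3)
for dim in (0, 1, 2, -1, -2, -3):
    print(dim, tuple(torch.unsqueeze(x, dim).shape))
# 0 (1, 6, 2)
# 1 (6, 1, 2)
# 2 (6, 2, 1)
# -1 (6, 2, 1)
# -2 (6, 1, 2)
# -3 (1, 6, 2)
```

Each negative dim produces the same shape as its non-negative counterpart dim + input.dim() + 1, so -1 appends a trailing dimension and -3 prepends a leading one.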

Because of the added dimension, we need an extra index to access the elements of x_unsqueeze. That is, we access the first row of x as x[0] and the first row of its x_unsqueeze counterpart as x_unsqueeze[0][0] (or, equivalently, x_unsqueeze[0, 0]).

print(x[0]) #tensor([7, 9], dtype=torch.int32)
print(x_unsqueeze[0][0]) #tensor([7, 9], dtype=torch.int32)

Note that this operation does not make a copy of the tensor data. Instead, x_unsqueeze uses the same underlying storage as x and only changes the size and stride information at the tensor level. This is convenient because the operation is very cheap, but keep in mind that an in-place change to one tensor is also visible in the other.

print(x.storage().data_ptr() == x_unsqueeze.storage().data_ptr()) #True
# x and x_unsqueeze share the same underlying data.
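The practical consequence of shared storage can be seen in a short sketch. It uses a smaller tensor than the one above to keep the output readable, and the clone() call is shown as one possible way to get an independent copy rather than as something the original example does.

```python
import torch

x = torch.tensor([[7, 9], [4, 3]], dtype=torch.int32)
x_unsqueeze = x.unsqueeze(0)  # shares storage with x

# An in-place write through x_unsqueeze is visible in x
x_unsqueeze[0, 0, 0] = 100
print(x[0, 0])  # tensor(100, dtype=torch.int32)

# clone() allocates new memory, so writes to the copy leave x unchanged
x_copy = x.unsqueeze(0).clone()
x_copy[0, 0, 0] = -1
print(x[0, 0])  # tensor(100, dtype=torch.int32)
```

If the original tensor has to stay untouched, cloning after unsqueeze costs an extra allocation but removes the aliasing.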
