Tensors are the primary data structure in PyTorch. They store numerical data and can be seen as a generalization of 1D arrays and 2D matrices to any number of dimensions, which lets them hold multidimensional data such as batches of three-channel images.

Tensors represent the inputs to a model, the weights of its layers, and its outputs. Standard linear algebra operations such as transposition, addition, multiplication, and inversion can all be run on tensors.
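As a rough illustration of the two paragraphs above, the sketch below builds a tensor shaped like a batch of images and runs a few linear algebra operations on a small matrix. The batch size, image size, and variable names are assumptions chosen for the example.

```python
import torch

# A hypothetical batch of 8 three-channel images, 32x32 pixels each,
# laid out as (batch, channels, height, width).
images = torch.zeros(8, 3, 32, 32)
print(images.shape)  # torch.Size([8, 3, 32, 32])

m = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])

transposed = m.T               # transposition
summed = m + m                 # element-wise addition
product = m @ m                # matrix multiplication
inverse = torch.linalg.inv(m)  # inversion (square, floating-point matrix)

print(product)
# tensor([[ 7., 10.],
#         [15., 22.]])
```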

The PyTorch API provides many tensor operations, ranging from tensor arithmetic to tensor indexing. In this post, we will cover some of the more useful operations for joining tensors.

To join tensors, you can use torch.cat to concatenate a sequence of tensors along a given dimension. torch.stack is another tensor-joining op that is subtly different from torch.cat.

Cat

The cat function concatenates the given sequence of tensors along the given dimension. As a consequence, that dimension changes size while the others stay the same: with dim=0 on 2D tensors, for example, you are appending rows, so the size of dimension 0 grows. torch.cat() is best understood via examples.

import torch

a = torch.tensor([[1, 2],
                   [3, 4]])

b = torch.tensor([[5, 6],
                   [7, 8]])

# Concatenate along dim=0: the rows of b are appended below the rows of a.
c = torch.cat((a, b), dim=0)
print(c)
print(c.shape)

#output

tensor([[1, 2],
        [3, 4],
        [5, 6],
        [7, 8]])

torch.Size([4, 2])

The cat function extends the tensors along the given dimension, e.g. by adding more rows or columns; the number of dimensions stays the same. All tensors must either have the same shape, except in the concatenating dimension, or be empty.
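The shape rule above can be sketched with a second example. It concatenates along dim=1 instead of dim=0, then joins tensors whose shapes differ only in the concatenating dimension. The names wide, tall, and extra_row are illustrative and not part of the PyTorch API.

```python
import torch

a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([[5, 6],
                  [7, 8]])

# dim=1 joins side by side: the row count stays 2, the column count grows to 4.
wide = torch.cat((a, b), dim=1)
print(wide)
# tensor([[1, 2, 5, 6],
#         [3, 4, 7, 8]])
print(wide.shape)  # torch.Size([2, 4])

# Shapes only need to match outside the concatenating dimension.
extra_row = torch.tensor([[9, 10]])      # shape (1, 2)
tall = torch.cat((a, extra_row), dim=0)  # shape (3, 2)
print(tall.shape)  # torch.Size([3, 2])

# Along dim=1 the row counts differ (2 vs 1), so cat raises an error.
try:
    torch.cat((a, extra_row), dim=1)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True
print(mismatch_raised)  # True
```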

Stack

The stack function plays a role similar to append on a list. It joins a sequence of tensors along a new dimension, so all of the tensors must have the same shape. It does not change the size of the existing dimensions; instead, it adds a new index to the result, so you can recover each original tensor by indexing along the new dimension.

import torch

a = torch.tensor([[1, 2],
                   [3, 4]])

b = torch.tensor([[5, 6],
                   [7, 8]])

# Stack along a new leading dimension: c[0] is a and c[1] is b.
c = torch.stack((a, b), dim=0)
print(c)
print(c.shape)

#output
tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]])
torch.Size([2, 2, 2])
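To make the indexing point concrete, here is a minimal sketch that recovers the inputs from a stacked tensor and shows where the new dimension lands for different values of dim. The three same-shaped tensors x, y, and z are assumptions made up for the shape demonstration.

```python
import torch

a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([[5, 6],
                  [7, 8]])

c = torch.stack((a, b), dim=0)

# Indexing the new leading dimension returns the original tensors.
print(torch.equal(c[0], a))  # True
print(torch.equal(c[1], b))  # True

# With dim=1 the new axis sits in the middle, so the inputs are found at [:, i].
d = torch.stack((a, b), dim=1)
print(torch.equal(d[:, 0], a))  # True

# Three tensors of shape (2, 4): the new axis always has size 3.
x, y, z = torch.zeros(2, 4), torch.ones(2, 4), torch.full((2, 4), 2.0)
print(torch.stack((x, y, z), dim=0).shape)  # torch.Size([3, 2, 4])
print(torch.stack((x, y, z), dim=1).shape)  # torch.Size([2, 3, 4])
print(torch.stack((x, y, z), dim=2).shape)  # torch.Size([2, 4, 3])
```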

The main difference between cat and stack is that cat joins the given sequence of tensors along an existing dimension, while stack joins them along a new dimension. In the examples above, cat turned two 2×2 tensors into one 4×2 tensor, whereas stack produced a 2×2×2 tensor.
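One way to see the relationship is that stacking gives the same result as adding a new axis to each tensor with unsqueeze and then concatenating along that axis. The sketch below checks this for the two tensors used in this post; the variable names are illustrative.

```python
import torch

a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([[5, 6],
                  [7, 8]])

concatenated = torch.cat((a, b), dim=0)
stacked = torch.stack((a, b), dim=0)

# cat keeps the number of dimensions; stack adds one.
print(concatenated.shape, concatenated.dim())  # torch.Size([4, 2]) 2
print(stacked.shape, stacked.dim())            # torch.Size([2, 2, 2]) 3

# stack is equivalent to unsqueezing each input and concatenating on the new axis.
via_cat = torch.cat((a.unsqueeze(0), b.unsqueeze(0)), dim=0)
print(torch.equal(stacked, via_cat))  # True
```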
