PyTorch’s gather function can be hard to understand at first, but it is very useful. So what is it for? Let’s work through an example.

Basically, it gathers values along an axis (either along the rows or along the columns). Consider a 4×4 tensor holding a batch of 4 examples with 4 features each. Imagine the following situation: you would like to select the 2nd feature from the 1st example, the 4th feature from the 2nd example, the 1st from the 3rd, and the 3rd from the 4th.

[Figure: gather selecting one feature from each example]

This is what the gather function is for: torch.gather() gathers values along the axis specified by dim. It is a multi-index selection function over a batch of examples. It takes three parameters:

  • input: the input tensor that we want to select elements from.
  • dim: the dimension (or axis) along which to gather.
  • index: a tensor of integer (int64) indices, with the same number of dimensions as input, that says which element to take.

torch.gather(input=input, dim=0, index=indx)

torch.gather() creates a new tensor from the input tensor by picking values along the given dimension. The indices are passed as a tensor that specifies which value to take from each ‘row’ or ‘column’, and the output has the same shape as that index tensor.
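The situation described at the start (2nd feature from the 1st example, 4th from the 2nd, 1st from the 3rd, 3rd from the 4th) can be sketched as follows. The tensor values and the names batch and picks are made up for illustration; only torch.gather itself comes from the text.

```python
import torch

# Hypothetical 4x4 batch: 4 examples (rows), 4 features (columns).
batch = torch.tensor([[10, 11, 12, 13],
                      [20, 21, 22, 23],
                      [30, 31, 32, 33],
                      [40, 41, 42, 43]])

# One zero-based column index per example: 2nd, 4th, 1st, 3rd feature.
# The index is 2-D (shape 4x1) because it must have as many dimensions as the input.
picks = torch.tensor([[1], [3], [0], [2]])

selected = torch.gather(batch, dim=1, index=picks)  # shape (4, 1)
print(selected.squeeze(1).tolist())  # [11, 23, 30, 42]
```

Because the index has one column, the result has one column too; squeeze(1) just flattens it into a vector of 4 values.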

For these small exercises with rows and columns, let’s start with a small 4×4 input.

import torch

input = torch.tensor([[ 1,  2,  3,  4],
                      [ 5,  6,  7,  8],
                      [ 9, 10, 11, 12],
                      [13, 14, 15, 16]])

If dim=1 then select from the columns

Now let’s start by indexing along the columns (dim=1) and create an index tensor.

indx = torch.tensor([[0, 1, 2, 3],
                     [3, 2, 1, 0],
                     [2, 3, 0, 1],
                     [1, 2, 1, 0]])

Rule 1: Since we are indexing along the columns and the input has 4 rows, the index can have at most 4 rows: one list for each input row we want to read from. Running torch.gather(input=input, dim=1, index=indx) will give us:

# output
tensor([[ 1,  2,  3,  4],
        [ 8,  7,  6,  5],
        [11, 12,  9, 10],
        [14, 15, 14, 13]])

Each list within the index tells us which columns to take from the matching row of the input. The 1st list of the index ([0,1,2,3]) looks at the 1st row of the input and takes columns 0, 1, 2, 3 of that row (it’s zero-indexed), which gives [1,2,3,4].

The 2nd list of the index ([3,2,1,0]) looks at the 2nd row of the input and takes columns 3, 2, 1, 0 of that row, which gives [8,7,6,5]. Jumping to the 4th list of the index ([1,2,1,0]), which looks at the 4th and final row of the input, we take columns 1, 2, 1, 0 of the 4th row, which gives [14,15,14,13].
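The row-by-row walk-through above amounts to the rule out[i][j] = input[i][index[i][j]]. A minimal sketch that spells this out with explicit loops and compares it with torch.gather; the helper name gather_dim1 is hypothetical and exists only for illustration.

```python
import torch

input = torch.tensor([[ 1,  2,  3,  4],
                      [ 5,  6,  7,  8],
                      [ 9, 10, 11, 12],
                      [13, 14, 15, 16]])
indx = torch.tensor([[0, 1, 2, 3],
                     [3, 2, 1, 0],
                     [2, 3, 0, 1],
                     [1, 2, 1, 0]])

def gather_dim1(src, idx):
    # Hypothetical helper: out[i][j] = src[i][idx[i][j]]
    out = torch.empty(idx.shape, dtype=src.dtype)
    for i in range(idx.size(0)):        # row of the index (and of the input)
        for j in range(idx.size(1)):    # position inside that row's list
            out[i, j] = src[i, idx[i, j]]
    return out

manual = gather_dim1(input, indx)
builtin = torch.gather(input, dim=1, index=indx)
print(torch.equal(manual, builtin))  # True
print(manual[1].tolist())            # [8, 7, 6, 5]
```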

Rule 2: The number of lists in your index has to be less than or equal to the number of rows in the input, but each list may be as long as you like! For example:

index1 = torch.tensor([[3, 2, 1, 0, 0, 1, 2, 3],
                       [2, 1, 3, 1, 2, 3, 0, 1]])

torch.gather(input=input, dim=1, index=index1)

# output
tensor([[4, 3, 2, 1, 1, 2, 3, 4],
        [7, 6, 8, 6, 7, 8, 5, 6]])

The output always has the same shape as the index: the same number of rows, and as many columns as the length of each list in the index.

With dim=1, the index can have at most as many rows as the input, but each row can be as long, or as short, as you like. Each value in the index must be a valid column number, i.e. less than the number of columns in the input.
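These dim=1 rules can be checked with a short sketch. The name bad_index is made up for illustration, and the exact exception type for an out-of-range index is an assumption based on CPU behaviour, so the sketch catches both RuntimeError and IndexError.

```python
import torch

input = torch.arange(1, 17).reshape(4, 4)  # same values as the 4x4 input above

# 2 rows (no more than the input's 4 rows), 8 columns, every value in 0..3.
index1 = torch.tensor([[3, 2, 1, 0, 0, 1, 2, 3],
                       [2, 1, 3, 1, 2, 3, 0, 1]])
out = torch.gather(input, dim=1, index=index1)
print(out.shape == index1.shape)  # True

# 4 is not a valid column number for an input with 4 columns (valid: 0..3).
bad_index = torch.tensor([[4]])
try:
    torch.gather(input, dim=1, index=bad_index)
    raised = False
except (RuntimeError, IndexError):
    raised = True
print(raised)  # True on CPU
```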

If dim=0 then select from the rows

Switching to dim=0, we’ll now be selecting along the rows as opposed to the columns. Using the same input, we now need an index in which each list is no wider than the number of columns in the input (here we use the full width of 4). Why? Because the element at position j of each list is the row number to read from column j of the input, as we move column by column.

input = torch.tensor([[ 1,  2,  3,  4],
                      [ 5,  6,  7,  8],
                      [ 9, 10, 11, 12],
                      [13, 14, 15, 16]])

indx = torch.tensor([[0, 1, 2, 3],
                     [3, 2, 1, 0],
                     [2, 3, 0, 1],
                     [1, 2, 1, 0]])

torch.gather(input=input, dim=0, index=indx)

# output
tensor([[ 1,  6, 11, 16],
        [13, 10,  7,  4],
        [ 9, 14,  3,  8],
        [ 5, 10,  7,  4]])

Looking at the 1st list in the index ([0,1,2,3]), we take the 0th element (it’s zero-indexed) of the 1st column, which is 1; the 1st element of the 2nd column, which is 6; the 2nd element of the 3rd column, which is 11; and the 3rd element of the 4th column, which is 16.
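For dim=0 the rule becomes out[i][j] = input[index[i][j]][j]: the index value picks the row, and the position in the list picks the column. A minimal loop-based sketch of this rule, where the helper name gather_dim0 is hypothetical:

```python
import torch

input = torch.tensor([[ 1,  2,  3,  4],
                      [ 5,  6,  7,  8],
                      [ 9, 10, 11, 12],
                      [13, 14, 15, 16]])
indx = torch.tensor([[0, 1, 2, 3],
                     [3, 2, 1, 0],
                     [2, 3, 0, 1],
                     [1, 2, 1, 0]])

def gather_dim0(src, idx):
    # Hypothetical helper: out[i][j] = src[idx[i][j]][j]
    out = torch.empty(idx.shape, dtype=src.dtype)
    for i in range(idx.size(0)):        # which list of the index
        for j in range(idx.size(1)):    # which column of the input
            out[i, j] = src[idx[i, j], j]
    return out

manual = gather_dim0(input, indx)
builtin = torch.gather(input, dim=0, index=indx)
print(torch.equal(manual, builtin))  # True
print(manual[0].tolist())            # [1, 6, 11, 16]
```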

With dim=0, the output again has the same shape as the index. Each list in the index can be at most as wide as the number of columns in the input, but we can now have as many lists as we like. Each value in the index needs to be less than the number of rows in the input. For example, with 7 lists:

input = torch.tensor([[ 1,  2,  3,  4],
                      [ 5,  6,  7,  8],
                      [ 9, 10, 11, 12],
                      [13, 14, 15, 16]])

index1 = torch.tensor([[0, 1, 2, 3],
                       [3, 2, 1, 0],
                       [2, 3, 0, 1],
                       [1, 2, 1, 0],
                       [2, 1, 3, 1],
                       [0, 2, 1, 3],
                       [1, 3, 2, 0]])

torch.gather(input=input, dim=0, index=index1)

# output
tensor([[ 1,  6, 11, 16],
        [13, 10,  7,  4],
        [ 9, 14,  3,  8],
        [ 5, 10,  7,  4],
        [ 9,  6, 15,  8],
        [ 1, 10,  7, 16],
        [ 5, 14, 11,  4]])
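As a further check on the dim=0 shape rules, here is a sketch with an index that is narrower than the input. The name narrow and its values are made up for illustration. The lists only cover the first 2 columns, so the output has 2 columns as well.

```python
import torch

input = torch.tensor([[ 1,  2,  3,  4],
                      [ 5,  6,  7,  8],
                      [ 9, 10, 11, 12],
                      [13, 14, 15, 16]])

# 3 lists, each only 2 wide: position 0 reads from column 0, position 1 from column 1.
narrow = torch.tensor([[3, 0],
                       [1, 2],
                       [0, 0]])

out = torch.gather(input, dim=0, index=narrow)
print(out.shape == narrow.shape)  # True
print(out.tolist())               # [[13, 2], [5, 10], [1, 2]]
```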
