An input to a PyTorch convolutional model contains matrices with dimensions N1 × N2 (the image height and width in pixels). These N1 × N2 matrices are called channels. Conventional implementations of the convolutional layer expect each input image as a rank-3 tensor, for example a three-dimensional array of size N1 × N2 × Cin, where Cin is the number of input channels. PyTorch orders these dimensions channels first, as Cin × N1 × N2.

Let’s consider images as input to the first layer of a CNN. If the image is colored and uses the RGB color mode, then Cin = 3 (for the red, green, and blue color channels in RGB). However, if the image is in grayscale, then we have Cin = 1 because there is only one channel with the grayscale pixel intensity values. 
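To make Cin concrete, here is a minimal sketch with random tensors standing in for images. The sizes are hypothetical, and the tensors use PyTorch's channels-first layout (Cin × N1 × N2). The sketch shows that each channel is an N1 × N2 matrix and that the first layer's in_channels argument must equal Cin.

```python
import torch
from torch import nn

# Hypothetical tiny image: N1 = 4 rows (height), N2 = 5 columns (width)
n1, n2 = 4, 5
rgb = torch.rand(3, n1, n2)   # Cin = 3: one N1 x N2 matrix each for red, green, and blue
gray = torch.rand(1, n1, n2)  # Cin = 1: a single N1 x N2 matrix of intensities

print(rgb[0].shape)  # torch.Size([4, 5]): each channel is an N1 x N2 matrix

# The first layer's in_channels must equal Cin of the images it receives
conv_rgb = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3)
conv_gray = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3)
print(conv_rgb.weight.shape)   # torch.Size([8, 3, 3, 3])
print(conv_gray.weight.shape)  # torch.Size([8, 1, 3, 3])

# unsqueeze(0) adds the batch dimension that Conv2d expects
out_rgb = conv_rgb(rgb.unsqueeze(0))
out_gray = conv_gray(gray.unsqueeze(0))
print(out_rgb.shape)   # torch.Size([1, 8, 2, 3])
print(out_gray.shape)  # torch.Size([1, 8, 2, 3])
```

Only the second dimension of the weight tensor changes between the two layers, so a model built for RGB input cannot be fed grayscale images unchanged.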

Load an image file

We can read images into NumPy arrays using the uint8 (unsigned 8-bit integer) data type to reduce memory usage. Unsigned 8-bit integers take values in the range [0, 255], which is sufficient to store the pixel values of RGB images, since these take values in the same range.
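The memory saving is easy to check. The sketch below assumes a hypothetical all-black 536 × 800 RGB image (the same size as the example image loaded below) and compares its size in uint8 with the same data stored as float32.

```python
import numpy as np

# Hypothetical all-black RGB image, stored channels-last as NumPy image arrays usually are
h, w, c = 536, 800, 3
img_u8 = np.zeros((h, w, c), dtype=np.uint8)
img_f32 = img_u8.astype(np.float32)

info = np.iinfo(np.uint8)
print(info.min, info.max)  # 0 255
print(img_u8.nbytes)       # 1286400 bytes: 1 byte per value
print(img_f32.nbytes)      # 5145600 bytes: 4 bytes per value
```

Converting to a floating-point type is usually deferred until the data is fed to a model, so the raw images stay four times smaller.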

PyTorch provides image-loading functions through the torchvision.io module. Let's read an example RGB image:

import torch
from torchvision.io import read_image

import matplotlib.pyplot as plt

img = read_image('/content/cow.jpg')  # returns a uint8 tensor

print('Image shape:', img.shape)            # Image shape: torch.Size([3, 536, 800])
print('Number of channels:', img.shape[0])  # Number of channels: 3
print('Image data type:', img.dtype)        # Image data type: torch.uint8

Note that with torchvision, the input and output image tensors are in the format of Tensor[channels, image_height, image_width]. 

With torchvision, the channel dimension is therefore the first dimension of an image tensor (or the second dimension once a batch dimension is added).

This is called the NCHW format, where N stands for the number of images within the batch, C stands for channels, and H and W stand for height and width, respectively. 
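A minimal sketch of going from single C × H × W images to an NCHW batch, using random uint8 tensors as hypothetical stand-ins for images returned by read_image:

```python
import torch

# Two hypothetical RGB images in C x H x W layout
img_a = torch.randint(0, 256, (3, 536, 800), dtype=torch.uint8)
img_b = torch.randint(0, 256, (3, 536, 800), dtype=torch.uint8)

single = img_a.unsqueeze(0)          # add a batch dimension: N = 1
batch = torch.stack([img_a, img_b])  # stack along a new first dimension: N = 2

print(single.shape)  # torch.Size([1, 3, 536, 800])
print(batch.shape)   # torch.Size([2, 3, 536, 800])

n, c, h, w = batch.shape  # unpack in NCHW order
```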

Note that PyTorch's torch.nn.Conv2d class expects inputs in NCHW format. Other tools use a channels-last layout: TensorFlow defaults to NHWC, and Matplotlib's imshow expects a single image as H × W × C.
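The sketch below feeds a hypothetical batch of two small 32 × 48 RGB images to nn.Conv2d. It also shows two things the layer is strict about: the input must be floating point to match the float32 weights, and passing the same data in NHWC order raises an error because dimension 1 is read as the channel count.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical batch of two 32 x 48 RGB images in NCHW layout, stored as uint8
batch = torch.randint(0, 256, (2, 3, 32, 48), dtype=torch.uint8)
x = batch.float() / 255.0  # Conv2d weights are float32, so convert and scale to [0, 1]

conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)
y = conv(x)
print(y.shape)  # torch.Size([2, 16, 30, 46]): no padding, so H and W shrink by 2

# The same data in NHWC layout is rejected: dim 1 (here 32) is taken as the channel count
channels_last_failed = False
try:
    conv(x.permute(0, 2, 3, 1))
except RuntimeError:
    channels_last_failed = True
print(channels_last_failed)  # True
```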

However, if you come across data whose channels are stored in the last dimension (for example, the H × W × C NumPy array that np.asarray returns for an RGB PIL image), you need to swap the axes to move the channels to the first dimension.
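A minimal sketch of that conversion, assuming a random H × W × C uint8 array as a hypothetical stand-in for channels-last image data. torch.from_numpy wraps the array, and permute moves old dimension 2 (the channels) to the front.

```python
import numpy as np
import torch

rng = np.random.default_rng(0)
# Hypothetical channels-last RGB image: H x W x C
hwc = rng.integers(0, 256, size=(536, 800, 3), dtype=np.uint8)

chw = torch.from_numpy(hwc).permute(2, 0, 1)  # old dim 2 (C) first, then H and W
print(chw.shape)  # torch.Size([3, 536, 800])

# The pixel at row 10, column 20 keeps the same R, G, B values after the move
same_pixel = bool((chw[:, 10, 20].numpy() == hwc[10, 20, :]).all())
print(same_pixel)  # True

# A channels-last batch (N x H x W x C) becomes NCHW with permute(0, 3, 1, 2)
nhwc = torch.zeros(4, 32, 48, 3)
nchw = nhwc.permute(0, 3, 1, 2)
print(nchw.shape)  # torch.Size([4, 3, 32, 48])
```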

The tensor's permute method takes, for each new dimension, the index of the old dimension that should go there. It works in the opposite direction too: to display the C × H × W tensor obtained previously with Matplotlib, we put old dimension 1 (H) first, followed by old dimensions 2 (W) and 0 (C):

out = img.permute(1, 2, 0)  # C x H x W -> H x W x C
print('Image shape:', out.shape)  # Image shape: torch.Size([536, 800, 3])

plt.imshow(out)
plt.show()

tensor.permute does not copy the tensor data. Instead, out shares the underlying storage with img; only the size and stride metadata of the tensor change.

This is convenient because the operation is very cheap, but just as a heads-up: changing a pixel in img will lead to a change in out.
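The sharing can be observed directly. This sketch uses a small hypothetical all-zero 3 × 4 × 5 tensor in place of the loaded image, compares memory addresses and strides, mutates img, and then uses contiguous() to obtain an independent copy.

```python
import torch

img = torch.zeros(3, 4, 5, dtype=torch.uint8)  # small hypothetical C x H x W image
out = img.permute(1, 2, 0)                     # H x W x C view; no data is copied

print(out.data_ptr() == img.data_ptr())  # True: both start at the same memory address
print(img.stride(), out.stride())        # (20, 5, 1) (5, 1, 20): only strides are reordered
print(out.is_contiguous())               # False

img[0, 1, 2] = 255   # change channel 0 of the pixel at row 1, column 2 in img...
print(out[1, 2, 0])  # tensor(255, dtype=torch.uint8): ...and out sees the change

copy = out.contiguous()  # allocates new memory laid out in H x W x C order
img[0, 1, 2] = 0
print(copy[1, 2, 0])     # tensor(255, dtype=torch.uint8): the copy is independent
```

Calling contiguous() is the usual way to break the link when later code must not see changes to the original tensor.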


Note how we have to use permute to change the order of the axes from C × H × W to H × W × C to match what Matplotlib expects.
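Going back to PyTorch's layout after display only needs the inverse permutation. The sketch below uses a random hypothetical C × H × W tensor; it also shows torch.movedim, which expresses the same move by naming a source and a destination dimension.

```python
import torch

img = torch.randint(0, 256, (3, 536, 800), dtype=torch.uint8)  # hypothetical C x H x W image

hwc = img.permute(1, 2, 0)   # C x H x W -> H x W x C, for Matplotlib
back = hwc.permute(2, 0, 1)  # inverse permutation: H x W x C -> C x H x W
print(hwc.shape)                # torch.Size([536, 800, 3])
print(torch.equal(back, img))   # True

# movedim(0, -1) moves the channel dimension to the end, matching permute(1, 2, 0)
print(torch.equal(torch.movedim(img, 0, -1), hwc))  # True
```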
