Image normalization and augmentation can improve a model's performance and generalization. In this post, I normalize images and apply common image augmentation techniques to increase the effective size and diversity of the images flowing through a PyTorch 2.0 DataPipe. We will explore simple transformations, like cropping, flipping, and Gaussian blur.

Dataset

The “Lions or Cheetahs” dataset is a collection of images downloaded from the Open Images Dataset V6, containing photographs of both lions and cheetahs. The dataset contains a total of 200 images. The images have been labeled as either “lion” or “cheetah” and are stored in separate directories within the dataset.

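import os

# Kaggle API credentials, taken from your kaggle.json file.
# Never publish a real API key; the value below is a placeholder.
os.environ['KAGGLE_USERNAME'] = "brijesh123"  # username from the json file
os.environ['KAGGLE_KEY'] = "<your-kaggle-api-key>"  # key from the json file

!kaggle datasets download -d mikoajfish99/lions-or-cheetahs-image-classification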

PyTorch 2.0 DataPipe

DataPipes are composable Iterable-style and Map-style building blocks that work well out of the box with DataLoader2. The built-in DataPipes cover the usual needs: loading files, parsing, caching, transforming, filtering, and many more utilities.

from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes.iter import FileLister, FileOpener, IterableWrapper, StreamReader

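# `path` must point to the directory where the dataset was extracted.
datapipe = FileLister(root=path, masks='*.jpg', recursive=True)  # yields file paths
datapipe = datapipe.shuffle()  # shuffle paths before any file is opened
datapipe = FileOpener(datapipe, mode='b')  # yields (path, binary stream)
datapipe = StreamReader(datapipe)  # yields (path, bytes)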

This example shows how DataPipes can be chained together to compose a graph of transformations that reproduces a sophisticated data pipeline while streaming the data one item at a time.
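To make the idea of chaining more concrete, here is a minimal sketch that mimics the same pattern with plain Python generators. It does not use torchdata at all: the helper names (list_files, shuffle_items, read_items) and the fake file contents are hypothetical stand-ins, chosen only to show how each stage wraps the previous one and how items stream through the chain.

```python
import random
from typing import Callable, Iterable, Iterator, Tuple


def list_files(names: Iterable[str], suffix: str) -> Iterator[str]:
    """Yield only the names ending with `suffix` (stand-in for a file lister)."""
    for name in names:
        if name.endswith(suffix):
            yield name


def shuffle_items(items: Iterable[str], seed: int) -> Iterator[str]:
    """Yield the items in random order (this simple version buffers all of them)."""
    buffered = list(items)
    random.Random(seed).shuffle(buffered)
    yield from buffered


def read_items(
    items: Iterable[str], loader: Callable[[str], bytes]
) -> Iterator[Tuple[str, bytes]]:
    """Yield (name, contents) pairs (stand-in for opening and reading a file)."""
    for name in items:
        yield name, loader(name)


def fake_loader(name: str) -> bytes:
    """Pretend file contents, so the sketch needs no files on disk."""
    return name.encode("utf-8")


names = ["Lions/a.jpg", "Cheetahs/b.jpg", "notes.txt", "Lions/c.jpg"]

# Each stage takes the previous stage as its input, like chained DataPipes.
pipeline = list_files(names, ".jpg")
pipeline = shuffle_items(pipeline, seed=0)
pipeline = read_items(pipeline, fake_loader)

records = list(pipeline)  # nothing runs until the chain is iterated
```

Nothing is computed when the stages are assembled; work only happens when the final object is iterated, which is the same lazy behavior that Iterable-style DataPipes rely on.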

DataPipe map function

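The map function applies a function to each item of the source DataPipe. Here it decodes each file into a PIL image and attaches a label. Most torchvision transformations accept both PIL images and tensor images, although some transformations are PIL-only and some are tensor-only.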

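def load_image(item):
    # Each item is a (path, bytes) tuple produced by StreamReader.
    path, data = item
    # Label 0 for cheetahs, 1 for lions, based on the directory name in the path.
    label = 0 if 'Cheetahs' in path else 1
    # Convert to RGB so every image has the three channels Normalize expects.
    img = Image.open(BytesIO(data)).convert('RGB')
    return {'image': img, 'label': label}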

datapipe = datapipe.map(fn=load_image)

Normalize Image

Normalization can be an effective way to speed up training, because a neural network typically learns faster when its input channels share a similar scale. There are two steps to normalize an image:

  • Subtract the channel mean from each input channel.
  • Divide the result by the channel standard deviation.
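The two steps can be sketched by hand with NumPy. This is only an illustration of the arithmetic, not the torchvision implementation: the function names normalize and denormalize are hypothetical, the image is random data in channel-first (C, H, W) layout, and the mean and standard deviation are the values used in the code of this post.

```python
import numpy as np

# Per-channel statistics (the same values used in this post).
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def normalize(image: np.ndarray) -> np.ndarray:
    """Subtract the channel mean, then divide by the channel std (C, H, W input)."""
    return (image - MEAN[:, None, None]) / STD[:, None, None]


def denormalize(image: np.ndarray) -> np.ndarray:
    """Invert normalize(): multiply by the std, then add the mean back."""
    return image * STD[:, None, None] + MEAN[:, None, None]


rng = np.random.default_rng(0)
image = rng.random((3, 4, 4))  # fake RGB image with values in [0, 1)

normalized = normalize(image)
restored = denormalize(normalized)  # recovers the original image
```

The inverse operation matters in practice: a normalized image has to be denormalized before it can be displayed with its original colors.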

Augmenting data with PyTorch is very straightforward. We can use the transforms provided in torchvision.transforms. To compose several transforms together, we use torchvision.transforms.Compose and pass the transforms as a list. The transforms are applied in list order.

data_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Apply the transform to the 'image' field of each item from the source DataPipe.
datapipe = datapipe.map(fn=data_transform, input_col='image')

For training, we typically include transforms.ToTensor to convert the images to PyTorch tensors and transforms.Normalize to normalize them according to the network that you will train. The input_col argument gives the index or indices of the data to which the transform is applied; here it is the dictionary key 'image', so the label is left untouched.
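The effect of input_col can be sketched in plain Python. The helper below, map_with_input_col, is a hypothetical stand-in and not the torchdata implementation; it only illustrates the idea that the function is applied to one field of each record while the other fields pass through unchanged. The tiny pixel lists are made-up data.

```python
from typing import Any, Callable, Dict, Iterable, Iterator


def map_with_input_col(
    records: Iterable[Dict[str, Any]],
    fn: Callable[[Any], Any],
    input_col: str,
) -> Iterator[Dict[str, Any]]:
    """Apply `fn` to record[input_col] only; every other field is kept as-is."""
    for record in records:
        updated = dict(record)  # shallow copy so the input records stay intact
        updated[input_col] = fn(record[input_col])
        yield updated


def scale(pixels):
    """Scale 8-bit pixel values to the range [0, 1]."""
    return [p / 255 for p in pixels]


records = [
    {"image": [0, 128, 255], "label": 0},
    {"image": [64, 64, 64], "label": 1},
]

scaled = list(map_with_input_col(records, scale, input_col="image"))
```

Only the 'image' entries change; the 'label' entries are carried along, which is why the label survives every transform applied in this post.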

Image Augmentation

This section applies several of the transformations available in the torchvision.transforms module: a random resized crop, a random horizontal flip, and a Gaussian blur.

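data_augment = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),  # random crop, resized to 224x224
        transforms.RandomHorizontalFlip(),  # flip horizontally with probability 0.5
        transforms.GaussianBlur(kernel_size=(51, 91), sigma=4),  # kernel sizes must be odd
    ]
)

datapipe = datapipe.map(fn=data_augment, input_col='image')
datapipe = datapipe.batch(8)  # group items into batches of 8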
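As a rough illustration of what batching does to a stream, here is a plain-Python sketch. The function batch below is a hypothetical helper written for this example, not the torchdata implementation; it groups consecutive items into lists of a fixed size and, by default, keeps a shorter final batch.

```python
from itertools import islice
from typing import Any, Iterable, Iterator, List


def batch(items: Iterable[Any], batch_size: int, drop_last: bool = False) -> Iterator[List[Any]]:
    """Group consecutive items into lists of `batch_size` items."""
    iterator = iter(items)
    while True:
        chunk = list(islice(iterator, batch_size))
        if not chunk:
            return  # stream exhausted
        if len(chunk) < batch_size and drop_last:
            return  # discard the incomplete final batch
        yield chunk


batches = list(batch(range(20), batch_size=8))
sizes = [len(b) for b in batches]  # two full batches and one partial batch
```

Each batch here is simply a list of items, not a stacked tensor, which is why the plotting code later in this post indexes into the batch to reach each image and label.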

DataLoader2

On top of DataPipes, torchdata provides a new DataLoader2 that runs these data pipelines in various settings and on different execution backends through a ReadingService.

rs = MultiProcessingReadingService(num_workers=1)
dl = DataLoader2(datapipe, reading_service=rs)


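batch = next(iter(dl))  # first batch: a list of 8 {'image', 'label'} dicts

# Normalization statistics used earlier, needed to undo Normalize for display.
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

figure = plt.figure(figsize=(8, 4))
cols, rows = 4, 2

for i in range(cols * rows):
    label = batch[i]['label']
    img = batch[i]['image']

    # Convert from (C, H, W) to (H, W, C) and undo the normalization.
    inp = img.numpy().transpose((1, 2, 0))
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)

    figure.add_subplot(rows, cols, i + 1)
    plt.title(label)
    plt.axis("off")
    plt.imshow(inp)
plt.show()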

Figure: PyTorch Transformer DataPipe Using Map
