Deep learning models usually require a lot of data for training. In general, the more data, the better the model performs. But acquiring massive amounts of data comes with its own challenges. Instead of spending days manually collecting data, we can make use of image augmentation techniques. Image augmentation creates new variations of existing images without time-consuming manual effort.

In this tutorial, we will understand the concept of image augmentation, why it’s helpful, and what the different image augmentation techniques are. We’ll also implement these techniques using torchvision.transforms.

Image augmentation is the process of generating new images for training a CNN model. These new images are generated from the existing training images, so we don’t have to collect them manually.

How is data augmentation performed using transforms?

The transforms are applied to your original images each time a batch is generated. Your training dataset on disk is left unchanged; only the images loaded for the current batch are copied and transformed at every iteration.

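Conceptually, data augmentation keeps the original dataset and adds other versions of it (flipped, rotated, cropped, etc.). With torchvision transforms, these versions are not stored: a new random variant of each image is produced every time it is loaded, so the number of images per epoch stays the same while the model sees different variants across epochs.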
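To make the distinction concrete, here is a small plain-NumPy sketch (not torchvision code) contrasting offline augmentation, where flipped copies are stored alongside the originals, with the on-the-fly approach that transforms use. The helper names hflip, offline_augment and OnlineAugmentedDataset are hypothetical, and a horizontal flip stands in for any random transform.

```python
import numpy as np

def hflip(img):
    # Mirror the image left-to-right (axis 1 is the width axis of an H x W x C array)
    return img[:, ::-1]

def offline_augment(images):
    # Offline: store the originals plus a flipped copy of each, so the dataset doubles
    return images + [hflip(img) for img in images]

class OnlineAugmentedDataset:
    # Online: store only the originals and flip each one with probability p when it is loaded
    def __init__(self, images, p=0.5, seed=0):
        self.images = images
        self.p = p
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        img = self.images[i]
        # The stored image is never modified; a (possibly flipped) copy is returned
        return hflip(img).copy() if self.rng.random() < self.p else img.copy()

images = [np.arange(12).reshape(2, 2, 3) + 100 * k for k in range(4)]
print(len(offline_augment(images)))   # 8
online = OnlineAugmentedDataset(images)
print(len(online))                    # 4
```

The offline version multiplies storage, while the online version keeps the dataset size fixed and relies on repeated epochs to expose the model to different variants.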

There are multiple image augmentation techniques, and we will discuss some of the most common and widely used ones. I will be using 8 dog images to demonstrate them. You can try other images as well, depending on your requirements.

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import glob
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import numpy as np

img_list=glob.glob('DOG/*.jpeg')

class DogsDataset(Dataset):
  def __init__(self,image_list,transforms=None):
    self.image_list=image_list
    self.transforms=transforms
  def __len__(self):
    return len(self.image_list)
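  def __getitem__(self,i):
    # Load the image as an RGB uint8 NumPy array of shape H x W x 3
    img=np.array(Image.open(self.image_list[i]).convert('RGB'))

    if self.transforms is not None:
      # Each pipeline below starts with ToPILImage and ends with a tensor (3 x H x W, values in [0, 1])
      return self.transforms(img)
    # Without transforms, convert to a 3 x H x W float tensor in [0, 1] so show_img still works
    return torch.from_numpy(img).permute(2,0,1).float()/255.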

def show_img(img):
  plt.figure(figsize=(40,38))
  # make_grid returns 3 x H x W; imshow expects H x W x 3 with float values in [0, 1]
  npimg=np.clip(img.numpy(),0,1)
  plt.imshow(np.transpose(npimg,(1,2,0)))
  plt.show()
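ToTensor produces images in channels-first layout (C x H x W) with float values in [0, 1], while plt.imshow expects channels-last (H x W x C); that is why show_img transposes with (1, 2, 0). The following NumPy sketch (helper names are hypothetical) shows the two conversions side by side; it mirrors only the layout and scaling change and is not torchvision's implementation.

```python
import numpy as np

def to_chw_float(img_hwc_uint8):
    # Roughly what ToTensor does for a uint8 image: H x W x C in [0, 255] -> C x H x W float in [0, 1]
    return np.transpose(img_hwc_uint8, (2, 0, 1)).astype(np.float32) / 255.0

def to_hwc_for_display(img_chw):
    # What show_img does before plt.imshow: move channels last and keep values in [0, 1]
    return np.clip(np.transpose(img_chw, (1, 2, 0)), 0.0, 1.0)

img = np.zeros((164, 200, 3), dtype=np.uint8)
img[..., 0] = 255  # a pure red image
chw = to_chw_float(img)
print(chw.shape)                       # (3, 164, 200)
print(to_hwc_for_display(chw).shape)   # (164, 200, 3)
```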

Image Rotation

Image rotation helps our model become more robust to changes in the orientation of objects. The information in the image remains the same; for example, a dog is a dog even if we see it from a different angle. Let’s see how we can rotate it. We will use the RandomRotation transform from torchvision.transforms to rotate the images.

transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.Resize((164,164)),
                              transforms.RandomRotation(50,expand=True),  
                              transforms.Resize((164,164)),
                              transforms.ToTensor(),
                              ])

dog_dataloader=DataLoader(DogsDataset(img_list,transform),batch_size=8,shuffle=True)

data=iter(dog_dataloader)
show_img(torchvision.utils.make_grid(next(data)))

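RandomRotation(50) rotates every image by an angle sampled uniformly from [-50, 50] degrees, so each of the 8 images in the batch gets its own angle. Angles close to 0 produce images that look almost unchanged, which is why some images in the grid may appear unrotated. With expand=True, the output canvas is enlarged so the rotated image is not cut off at the corners; because the enlarged size depends on the angle, the images are resized back to 164x164 so they can be stacked into one batch.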
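For intuition, here is a rough NumPy sketch of what a random rotation does: sample one angle uniformly from [-degrees, degrees], then rotate about the image center. The helper names are hypothetical; the sketch uses nearest-neighbour sampling, keeps the original canvas size (like expand=False), and is not torchvision's implementation.

```python
import numpy as np

def rotate_nearest(img, angle_deg, fill=0):
    """Rotate an H x W (x C) array counter-clockwise about its center.

    Nearest-neighbour sampling on the same canvas size; corners that no
    source pixel maps to are set to `fill`.
    """
    h, w = img.shape[:2]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    theta = np.deg2rad(angle_deg)
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    ys, xs = np.indices((h, w))
    x_rel, y_rel = xs - cx, ys - cy
    # Inverse mapping: for every output pixel, find the source pixel it comes from
    src_x = np.rint(cos_t * x_rel - sin_t * y_rel + cx).astype(int)
    src_y = np.rint(sin_t * x_rel + cos_t * y_rel + cy).astype(int)
    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.full_like(img, fill)
    out[valid] = img[src_y[valid], src_x[valid]]
    return out

def random_rotation(img, degrees, rng):
    # Like RandomRotation(degrees): one angle drawn uniformly from [-degrees, degrees] per call
    angle = rng.uniform(-degrees, degrees)
    return rotate_nearest(img, angle), angle

rng = np.random.default_rng(0)
img = np.arange(25).reshape(5, 5)
rotated, angle = random_rotation(img, 50, rng)
print(-50 <= angle <= 50)   # True
```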

Random Cropping

Random cropping is one of the most important sources of image diversity. When your network is used by real users, the object in the image can be in a different position. Also, the object may fill the entire frame and yet not be fully visible (i.e., it is cut off at the edges). The following code crops a random 120x120 region from each image.

transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.RandomCrop((120,120)),        
                              transforms.ToTensor(),
                              ])

dog_dataloader=DataLoader(DogsDataset(img_list,transform),batch_size=8,shuffle=True)

data=iter(dog_dataloader)
show_img(torchvision.utils.make_grid(next(data)))
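The idea behind a random crop can be sketched in a few lines of NumPy: pick a top-left corner uniformly among all positions that keep the window inside the image. The function name is hypothetical and this is not torchvision's code. Note that the pipeline above does not resize before cropping, so it assumes every image is at least 120x120; torchvision's RandomCrop also offers a pad_if_needed option for smaller images.

```python
import numpy as np

def random_crop(img, size, rng):
    """Cut a random size=(crop_h, crop_w) window out of an H x W (x C) array."""
    crop_h, crop_w = size
    h, w = img.shape[:2]
    if crop_h > h or crop_w > w:
        raise ValueError(f"crop {size} is larger than image {(h, w)}")
    # Every top-left corner that keeps the window inside the image is equally likely
    top = rng.integers(0, h - crop_h + 1)
    left = rng.integers(0, w - crop_w + 1)
    return img[top:top + crop_h, left:left + crop_w], (top, left)

rng = np.random.default_rng(0)
img = np.random.default_rng(1).integers(0, 256, size=(164, 200, 3)).astype(np.uint8)
crop, (top, left) = random_crop(img, (120, 120), rng)
print(crop.shape)   # (120, 120, 3)
```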

Flipping Images

Your network will be trained on images that are randomly flipped. RandomHorizontalFlip and RandomVerticalFlip each flip an image with probability p (the default is 0.5; the code below uses 0.4 for both), and the two flips are decided independently. Let’s see how we can implement flipping.

transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.Resize((164,164)),
                              transforms.RandomVerticalFlip(0.4), 
                              transforms.RandomHorizontalFlip(0.4),        
                              transforms.ToTensor(),
                              ])

dog_dataloader=DataLoader(DogsDataset(img_list,transform),batch_size=8,shuffle=True)

data=iter(dog_dataloader)
show_img(torchvision.utils.make_grid(next(data)))

This is how we can flip the image and make more generalized models that will learn the patterns of the original as well as the flipped images. 
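A NumPy sketch of two independent random flips (the function name is hypothetical, not a torchvision API) shows the four possible outcomes. With p = 0.4 for both flips, an image stays unflipped with probability (1 - 0.4) * (1 - 0.4) = 0.36.

```python
import numpy as np

def random_flips(img, rng, p_horizontal=0.5, p_vertical=0.5):
    """Independently flip an H x W (x C) array left-right and/or top-bottom."""
    out = img
    if rng.random() < p_horizontal:
        out = out[:, ::-1]   # mirror along the width axis
    if rng.random() < p_vertical:
        out = out[::-1, :]   # mirror along the height axis
    return out

rng = np.random.default_rng(0)
img = np.arange(6).reshape(2, 3)
# Count how often each of the 4 outcomes (none, horizontal, vertical, both) occurs
counts = {}
for _ in range(10_000):
    key = random_flips(img, rng, 0.4, 0.4).tobytes()
    counts[key] = counts.get(key, 0) + 1
print(len(counts))   # 4
```

Because the two flips commute, the order in which they are applied (vertical first in the pipeline above) does not change the distribution of results.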

Brightness, Contrast, Saturation, Hue

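The quality of the images will not be the same across sources. Some images might be of very high quality while others might be just plain bad, for example too dark, too bright, or washed out. In such scenarios, we can randomly vary the photometric properties of the training images, which helps make our deep learning model more robust. torchvision.transforms provides the ColorJitter class for randomly changing the brightness, contrast, saturation, and hue of an image. In the code below, only brightness and contrast are varied, since saturation and hue are set to 0.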

transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.Resize((164,164)),
                              transforms.ColorJitter(brightness=0.1, contrast=0.2, saturation=0, hue=0),
                              transforms.ToTensor(),
                              ])

dog_dataloader=DataLoader(DogsDataset(img_list,transform),batch_size=8,shuffle=True)

data=iter(dog_dataloader)
show_img(torchvision.utils.make_grid(next(data)))
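As a rough sketch of what brightness and contrast jitter do to pixel values, here is a NumPy version for float RGB images in [0, 1]. It follows the factor range described in the ColorJitter documentation ([max(0, 1 - x), 1 + x]), but the function names are hypothetical, the grayscale weights are the common ITU-R 601 luma coefficients, and torchvision additionally applies the adjustments in a random order, which this sketch does not.

```python
import numpy as np

def adjust_brightness(img, factor):
    # Scale every pixel; factor < 1 darkens, factor > 1 brightens
    return np.clip(img * factor, 0.0, 1.0)

def adjust_contrast(img, factor):
    # Blend each pixel with the mean gray level; factor 0 gives a flat gray image
    gray_mean = (img @ np.array([0.299, 0.587, 0.114])).mean()
    return np.clip(factor * img + (1.0 - factor) * gray_mean, 0.0, 1.0)

def color_jitter(img, brightness, contrast, rng):
    # Sample one factor per property from [max(0, 1 - x), 1 + x]
    b = rng.uniform(max(0.0, 1.0 - brightness), 1.0 + brightness)
    c = rng.uniform(max(0.0, 1.0 - contrast), 1.0 + contrast)
    return adjust_contrast(adjust_brightness(img, b), c)

rng = np.random.default_rng(0)
img = rng.random((164, 164, 3))   # float RGB image in [0, 1)
jittered = color_jitter(img, brightness=0.1, contrast=0.2, rng=rng)
```

With brightness=0.1 and contrast=0.2 as in the pipeline above, the brightness factor falls in [0.9, 1.1] and the contrast factor in [0.8, 1.2], so the changes are deliberately mild.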

Gaussian Noise to Images

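Adding random noise to the images is also an image augmentation technique. Let’s see how we can do that. We will add Gaussian noise (random values drawn from a normal distribution) to every pixel. Note that this is different from a Gaussian blur filter, which smooths the image instead.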

class AddGaussianNoise(object):
    """Add per-pixel Gaussian noise to a tensor image."""
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        # torch.randn samples N(0, 1); scaling by std and shifting by mean gives N(mean, std^2) noise
        return tensor + torch.randn(tensor.size()) * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

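Transforms give you fine-grained control of the transformation pipeline. Compose simply calls each transform in order, so any callable object, such as the AddGaussianNoise class above, can be used as a custom transform. For even finer control, torchvision.transforms.functional provides functional versions of the built-in transforms.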
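To see why any callable fits into a pipeline, here is a minimal stand-in for Compose. It assumes only that each transform is called on the output of the previous one; SimpleCompose and Scale are hypothetical names, and the real Compose class is not reproduced here.

```python
import numpy as np

class SimpleCompose:
    """A minimal stand-in for transforms.Compose: call each transform in order."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

class Scale:
    # A hypothetical custom transform: any object with __call__ works
    def __init__(self, factor):
        self.factor = factor

    def __call__(self, img):
        return img * self.factor

pipeline = SimpleCompose([Scale(2.0), lambda x: x + 1.0])
print(pipeline(np.zeros(3)))   # [1. 1. 1.]
```

Order matters: swapping the two steps would give 2.0 instead of 1.0, which is why, for example, AddGaussianNoise must come after ToTensor in the pipeline.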

transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.Resize((164,164)),
                              transforms.ToTensor(),
                              AddGaussianNoise(0.1, 0.08)
                              ])

dog_dataloader=DataLoader(DogsDataset(img_list,transform),batch_size=8,shuffle=True)

data=iter(dog_dataloader)
show_img(torchvision.utils.make_grid(next(data)))

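This helps our model learn to separate signal from noise in an image. Here the noise has a mean of 0.1 and a standard deviation of 0.08 (you can change these values). Keep in mind that increasing the standard deviation adds more noise to the image, and decreasing it adds less. Because the mean is 0.1, the noise also brightens the image slightly, and since the output is not clamped, some pixel values can fall outside [0, 1].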
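The same noise model in plain NumPy (a hypothetical helper, not part of torchvision) makes it easy to check the statistics and to see the effect of clamping, which the AddGaussianNoise class above does not do. Clamping is a design choice: it keeps pixel values valid for display, at the cost of slightly distorting the noise distribution near 0 and 1.

```python
import numpy as np

def add_gaussian_noise(img, mean, std, rng, clip=True):
    # Same formula as AddGaussianNoise: img + N(0, 1) * std + mean
    noisy = img + rng.standard_normal(img.shape) * std + mean
    # Optional: clamp back to the valid [0, 1] range of a ToTensor image
    return np.clip(noisy, 0.0, 1.0) if clip else noisy

rng = np.random.default_rng(0)
img = np.full((256, 256, 3), 0.5)
noise = add_gaussian_noise(img, 0.1, 0.08, rng, clip=False) - img
print(noise.mean(), noise.std())   # close to 0.1 and 0.08
```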


Random Erasing

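RandomErasing randomly selects a rectangular region in an image and erases its pixels. By default it is applied with probability 0.5 and fills the region with 0 (black); passing value='random' fills it with random values instead. This generates training images with various levels of occlusion, which reduces the risk of overfitting and makes the model more robust to occlusion. Note that RandomErasing operates on tensors, so it is placed after ToTensor.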

transform=transforms.Compose([
                              transforms.ToPILImage(),
                              transforms.Resize((164,164)),   
                              transforms.ToTensor(),
                              transforms.RandomErasing(),  
                              ])

dog_dataloader=DataLoader(DogsDataset(img_list,transform),batch_size=8,shuffle=True)

data=iter(dog_dataloader)
show_img(torchvision.utils.make_grid(next(data)))
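Here is a NumPy sketch of the rectangle sampling, using the default parameter values from the RandomErasing documentation (p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0). The function name is hypothetical, it works on H x W x C arrays rather than C x H x W tensors, and the retry loop is an approximation of torchvision's behavior, not a copy of its code.

```python
import numpy as np

def random_erasing(img, rng, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0.0):
    """Erase one random rectangle of an H x W (x C) array with probability p."""
    if rng.random() >= p:
        return img
    h_img, w_img = img.shape[:2]
    area = h_img * w_img
    log_ratio = np.log(ratio)
    for _ in range(10):
        # Fraction of the image to erase, and the rectangle's aspect ratio (log-uniform)
        erase_area = area * rng.uniform(*scale)
        aspect = np.exp(rng.uniform(*log_ratio))
        h = int(round(np.sqrt(erase_area * aspect)))
        w = int(round(np.sqrt(erase_area / aspect)))
        if 0 < h < h_img and 0 < w < w_img:
            top = rng.integers(0, h_img - h + 1)
            left = rng.integers(0, w_img - w + 1)
            out = img.copy()   # leave the original untouched
            out[top:top + h, left:left + w] = value
            return out
    # No rectangle fit after 10 attempts: return the image unchanged
    return img

rng = np.random.default_rng(0)
img = np.ones((64, 64, 3))
erased = random_erasing(img, rng, p=1.0)
print(erased.shape)   # (64, 64, 3)
```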

These are some of the image augmentation techniques that help make our deep learning model robust and generalizable. They also effectively increase the size of the training set.

All transformations change the image in some way. They leave the original untouched and return a changed copy. Given the same input image, some transforms always apply the same change (e.g., converting it to a tensor or resizing it to a fixed shape). Others apply transformations with random parameters and return a different result each time (e.g., randomly cropping the image or randomly changing its brightness or saturation). This means that in every epoch the model sees a different version of the dataset.

The purpose of data augmentation is to make the training data cover more of the variation that the model is likely to encounter in unseen data.
