Natural images are messy, so there are a number of preprocessing operations we can use to make training easier. Neural networks usually take floating-point tensors as input, and they tend to train best when the input values lie roughly between 0 and 1, or between -1 and 1.
We'll cast the image tensor to a floating-point type and normalize the pixel values. Casting is easy, but normalization is trickier, because it depends on which range of input values we decide should map to 0 to 1 (or -1 to 1).
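One technique supported out of the box in PyTorch (through torchvision) is per-channel standardization, often loosely called image whitening. The idea is to zero-center every pixel by subtracting the dataset mean and then divide by the dataset standard deviation, so that each channel ends up with zero mean and unit variance. This puts all inputs on a common scale, which helps compensate for differences in dynamic range between images. (Strictly speaking, full whitening would also decorrelate the channels; Normalize does not do that.)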
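Written out for a single channel c, the operation described above maps each pixel value x to the following, where μ_c and σ_c are the mean and standard deviation of channel c computed over the dataset:

$$
x'_{c} = \frac{x_{c} - \mu_{c}}{\sigma_{c}}
$$

For example, with the red-channel statistics used later in this post (μ = 0.51, σ = 0.28), a red value of 0.79 becomes (0.79 - 0.51) / 0.28 = 1.0, that is, one standard deviation above the dataset mean.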
The simplest option is to divide the pixel values by 255, the largest value an 8-bit unsigned integer can represent:
image = image.astype(float) / 255
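A minimal NumPy sketch of this scaling, assuming the image is a uint8 array of shape (H, W, C); the tiny 2x2 image is made up for illustration. It casts to float32 rather than Python's float (which gives float64), since PyTorch models usually expect float32, and it also shows the optional shift from [0, 1] to [-1, 1].

```python
import numpy as np

# A hypothetical 2x2 RGB image stored as 8-bit unsigned integers (0..255).
image = np.array(
    [[[0, 128, 255], [64, 32, 16]],
     [[255, 255, 255], [0, 0, 0]]],
    dtype=np.uint8,
)

# Cast to float32 first, then divide by 255 (the largest uint8 value),
# so every pixel lands in [0.0, 1.0].
scaled = image.astype(np.float32) / 255.0
print(scaled.dtype, scaled.min(), scaled.max())  # float32 0.0 1.0

# Optional: map [0, 1] to [-1, 1] with an affine shift.
centered = scaled * 2.0 - 1.0
print(centered.min(), centered.max())  # -1.0 1.0
```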
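Fortunately, PyTorch's companion package torchvision includes many commonly used transforms for image processing. The operation described above is available as transforms.Normalize, which subtracts a per-channel mean and divides by a per-channel standard deviation. Note that transforms.ToTensor already converts a PIL image or uint8 array into a float tensor with values in [0, 1], so Normalize is typically applied right after it.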
Transforms are handy because we can chain them with transforms.Compose(). Attached to a dataset, they apply normalization and data augmentation on the fly each time the data loader fetches a sample:
from torchvision import transforms

data_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.51, 0.47, 0.35], [0.28, 0.25, 0.28]),
])
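The magic numbers for the mean, [0.51, 0.47, 0.35], and the standard deviation, [0.28, 0.25, 0.28], are per-channel (R, G, B) statistics computed over the entire image dataset; this technique is called dataset normalization. Because they describe one particular dataset, you should compute your own values from your training images.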
We can also expand our dataset artificially (data augmentation) by randomly cropping or flipping images and by modifying their saturation, brightness, and so on.
Applying these transformations helps us build networks that are robust to the kinds of variation present in natural images and that still predict accurately when inputs are distorted.