In order to train a PyTorch model you must read training data into memory, convert the data to PyTorch tensors, and serve the data in batches. This task is not trivial.

In this tutorial, we will use the PyTorch Dataset and DataLoader interfaces to serve up training data from a folder. A Dataset object gives indexed access to the training items (the ImageFolder dataset used here reads each image from disk when it is requested), and a DataLoader object fetches items from a Dataset and serves them up in batches.

Flowers Dataset

The flowers dataset consists of images of flowers with 5 possible class labels. Let’s download our training examples (it may take a while).

!wget http://download.tensorflow.org/example_images/flower_photos.tgz
!tar -xzf flower_photos.tgz
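Before building a Dataset it can help to check how the images are laid out on disk. The sketch below counts image files in each subfolder, assuming the archive has been unpacked into a folder named flower_photos with one subfolder per class. The helper name count_images_per_class and the list of file extensions are illustrative assumptions, not part of the original tutorial.

```python
from pathlib import Path

# Assumed set of image extensions; adjust if your files use others.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root):
    """Return {subfolder name: number of image files} for each subfolder of root."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip anything at the top level that is not a folder
        counts[class_dir.name] = sum(
            1
            for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts


if __name__ == "__main__":
    root = Path("flower_photos")  # assumed extraction location
    if root.is_dir():
        for name, n in count_images_per_class(root).items():
            print(f"{name}: {n} images")
```

Each subfolder name reported here is what a folder-based dataset would treat as a class label.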

Create Dataset from Folder

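In general, you must write code to create a Dataset that matches your data and problem scenario, and no two Dataset implementations are exactly the same. Here the images are stored in one subfolder per class, so the ready-made ImageFolder dataset from torchvision can be used instead of a custom class. A DataLoader object, on the other hand, is used in mostly the same way no matter which Dataset object it’s associated with.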
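To make the idea of a custom Dataset concrete, here is a minimal sketch of a map-style dataset. It is a toy example under stated assumptions: the class name SquaresDataset and its contents are hypothetical and unrelated to the flowers data. The only real requirements it illustrates are the `__len__` and `__getitem__` methods.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Toy map-style dataset (hypothetical): item i is (tensor([i]), i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of items; the DataLoader uses this to plan its batches.
        return self.n

    def __getitem__(self, idx):
        # Return one (input, target) pair for the given index.
        x = torch.tensor([float(idx)])
        y = idx * idx
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)

# The default collate function stacks the individual items into batch tensors.
xb, yb = next(iter(loader))
print(xb.shape, yb)
```

The DataLoader call is the same one you would use with any other Dataset; only the Dataset class changes from problem to problem.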

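import torch
from torchvision import datasets, transforms

# Augment the images, convert them to tensors, and normalize each channel
# with the ImageNet means and standard deviations
transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# Each subfolder of flower_photos is treated as one class
dataset = datasets.ImageFolder('/content/flower_photos', transform=transform)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

# Class names are taken from the subfolder names
class_names = dataset.classes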

The code fragment creates a Dataset from the images in a folder and passes it to the DataLoader constructor. The DataLoader object then serves up batches of 32 training items in random order.

In neural network terminology, an epoch is one pass through all source data, which here means one full iteration over the DataLoader. A DataLoader is an iterable; wrapping it in the enumerate() function yields tuples containing the zero-based index of the current batch and the batch itself. A DataLoader is also tightly coupled to its Dataset: each batch has the same structure as the items the Dataset returns (here, an image tensor and a class label).
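The iteration pattern described above can be sketched as follows. To keep the example self-contained it uses random tensors wrapped in a TensorDataset as a stand-in for the flower images; the sizes (100 items, 3x8x8 inputs, 5 classes) are assumptions chosen only for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for image data: 100 items of shape 3x8x8, labels in 0..4.
features = torch.randn(100, 3, 8, 8)
labels = torch.randint(0, 5, (100,))
dataset = TensorDataset(features, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

num_epochs = 2
batch_sizes = []
for epoch in range(num_epochs):
    # One full pass over the DataLoader is one epoch.
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        batch_sizes.append(inputs.shape[0])
        print(f"epoch {epoch} batch {batch_idx}: inputs {tuple(inputs.shape)}")
```

With 100 items and a batch size of 32, the last batch of each epoch holds only the 4 leftover items, since drop_last defaults to False.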

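import numpy as np
import matplotlib.pyplot as plt

def imshow(inp, title=None):
    """Display image for Tensor."""
    # Tensors are channels-first (C, H, W); matplotlib expects (H, W, C)
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # Undo the normalization applied by transforms.Normalize
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated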
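The line `inp = std * inp + mean` works because it inverts the per-channel normalization `(x - mean) / std`. A minimal NumPy sketch of that round trip, using a small random array as a stand-in for a real image (the 4x4 size and the seed are arbitrary assumptions):

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

rng = np.random.default_rng(0)
original = rng.random((4, 4, 3))  # stand-in image, H x W x C, values in [0, 1)

# Normalize per channel, then move channels first as a tensor would store them.
normalized_chw = ((original - mean) / std).transpose((2, 0, 1))

# Reverse both steps: back to H x W x C, then undo the normalization.
restored = std * normalized_chw.transpose((1, 2, 0)) + mean

print(np.allclose(restored, original))
```

The mean and std arrays have shape (3,), so NumPy broadcasts them across the last (channel) axis, which is why the transpose has to happen before the arithmetic.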

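# Fetch one batch of images and labels
image, label = next(iter(dataloader))

The call next(iter(dataloader)) returns a single batch of batch_size images together with their labels. Because the DataLoader shuffles the data, indexing the batch at any position gives you a random example.

imshow(image[0], class_names[label[0]])

PyTorch DataLoader Single Image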
