For any standard supervised learning task, we’ll split our data into training and validation sets. We want to ensure both sets represent the range of real-world input data. If either set is meaningfully different from our real-world use cases, it’s pretty likely that our model will behave differently than we expect.

We split our data into a training set, a validation set, and a test set. This enables a fair evaluation of our model, because we directly measure how well it generalizes to new data it has not yet seen.

Validation Set

At the end of each epoch, we want to measure how well our model is generalizing. To do this, we use an additional validation set, which also helps us detect overfitting during training. We split the data so that the model trains on one part of it while an independent part is kept aside for validation.

Full Training Dataset

Although we didn’t explicitly train our model on the validation set, we chose which epoch of training to use based on the model’s performance on that set. That is a subtle form of data leakage. In fact, we should expect real-world performance to be slightly worse than the validation score suggests, since the model that performs best on our validation set is unlikely to perform equally well on every other set of unseen data. For this reason, practitioners often split data into three sets:

  1. A training set.
  2. A validation set, used to determine which epoch of the model’s training to consider “best”
  3. A test set, used to actually predict performance for the model (as chosen by the validation set) on unseen, real-world data
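To make the three roles concrete, here is a minimal sketch that is not part of the weather example. It assumes synthetic one-dimensional data (y = 3x plus noise) and a one-parameter linear model trained with plain NumPy gradient descent. The validation loss picks the best epoch, and the test set is evaluated only once, for the model that validation chose.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed synthetic data for illustration: y = 3x + small Gaussian noise
X = rng.uniform(-1, 1, size=200)
y = 3.0 * X + rng.normal(0, 0.1, size=200)

# Shuffle once, then carve out 80% train, 10% validation, 10% test
perm = rng.permutation(len(X))
train, val, test = perm[:160], perm[160:180], perm[180:]


def mse(w, idx):
    """Mean squared error of the model y_hat = w * x on the given indices."""
    return float(np.mean((w * X[idx] - y[idx]) ** 2))


w = 0.0
best_w, best_epoch, best_val = w, -1, float("inf")
for epoch in range(20):
    # Train: one full-batch gradient step on the training set per epoch
    grad = np.mean(2 * (w * X[train] - y[train]) * X[train])
    w -= 0.5 * grad
    # Validate: measure generalization at the end of each epoch, keep the best
    val_loss = mse(w, val)
    if val_loss < best_val:
        best_w, best_epoch, best_val = w, epoch, val_loss

# Test: evaluated exactly once, for the model chosen by the validation set
test_loss = mse(best_w, test)
print(f"best epoch {best_epoch}, val MSE {best_val:.4f}, test MSE {test_loss:.4f}")
```

Because best_val was used to make a choice, test_loss is the fairer estimate to report.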

The cost is that a third set takes another nontrivial chunk away from our training data, which can hurt when data is already scarce.


In this tutorial, we use the weather dataset from Kaggle. This dataset contains 6862 images of different types of weather and can be used to build a photo-based weather classifier. The pictures are divided into 11 classes: dew, fog/smog, frost, glaze, hail, lightning, rain, rainbow, rime, sandstorm, and snow.

Create Train, Valid, and Test sets

Next, we create the Train, Valid, and Test sets as separate lists of image paths. These lists will be passed to the custom Dataset class we define below.

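import os
import glob

data_dir = '../content/dataset'
class_names = sorted(os.listdir(data_dir))  # one sub-folder per class; sorted for a stable order
num_class = len(class_names)

idx_to_class = {i: j for i, j in enumerate(class_names)}
class_to_idx = {value: key for key, value in idx_to_class.items()}

# Every image lives in data_dir/<class_name>/<file>; sort so the seeded split is reproducible
image_files = sorted(glob.glob(os.path.join(data_dir, '*', '*')))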

Shuffle the list before splitting. The files are grouped by class, so a sequential split would leave some classes out of one or more of the three splits. Shuffling amounts to finding a random permutation of the indices, and random_split does exactly this, returning Subset objects whose indices point back into the original list:


This function splits a dataset into non-overlapping new datasets of the given lengths. If a list of fractions that sum to 1 is given, the lengths are computed automatically as floor(frac * len(dataset)) for each fraction, and any leftover items are then handed out one at a time, round-robin, starting from the first split.
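As a worked example, the rule reproduces the split sizes printed below for the 6862 weather images: floor(0.8 * 6862) = 5489 and floor(0.1 * 6862) = 686 twice, which sum to 6861, so the one leftover image goes to the first (training) split. The helper split_lengths is a hypothetical plain-Python restatement of the documented rule, not PyTorch's own code.

```python
import math


def split_lengths(n, fractions):
    """Restate the documented rule random_split uses for fractional lengths."""
    lengths = [math.floor(n * frac) for frac in fractions]
    remainder = n - sum(lengths)
    # Hand out leftover items one at a time, round-robin from the first split
    for i in range(remainder):
        lengths[i % len(lengths)] += 1
    return lengths


print(split_lengths(6862, [0.8, 0.1, 0.1]))  # [5490, 686, 686]
```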

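import torch
from torch.utils.data import random_split

# Split the file list 80/10/10; the seeded generator makes the split reproducible
train_idx, test_idx, val_idx = random_split(image_files, [0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(42))

# Each split is a Subset; .indices holds positions into image_files
train_list = [image_files[i] for i in train_idx.indices]
test_list = [image_files[i] for i in test_idx.indices]
val_list = [image_files[i] for i in val_idx.indices]

print(len(train_idx.indices), len(test_idx.indices), len(val_idx.indices))   # 5490 686 686
print(len(train_list), len(test_list), len(val_list))                        # 5490 686 686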

We now have a shuffled index list for each split, and we used them to build the training, validation, and test file lists from the image file list.
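To check the claim about shuffling, the sketch below counts which classes land in each split. It assumes a hypothetical class-grouped file list (5 folders with 200 files each, not the real weather data); class_counts and split_80_10_10 are illustrative helpers. Because the class comes from the parent folder, as in our Dataset, the same class_counts helper can be applied to train_list, val_list, and test_list.

```python
import os
import random
from collections import Counter


def class_counts(paths):
    """Count files per class, using each file's parent folder name as its label."""
    return Counter(os.path.basename(os.path.dirname(p)) for p in paths)


def split_80_10_10(items):
    """Cut a list into 80% / 10% / 10% slices, in order."""
    n_train, n_val = int(0.8 * len(items)), int(0.1 * len(items))
    return items[:n_train], items[n_train:n_train + n_val], items[n_train + n_val:]


# Hypothetical listing: 5 class folders with 200 files each, grouped by class
class_folders = ["dew", "frost", "hail", "rain", "snow"]
files = [f"dataset/{c}/{i}.jpg" for c in class_folders for i in range(200)]

# Without shuffling, validation and test only ever see the last folder
_, val_seq, test_seq = split_80_10_10(files)
print(sorted(class_counts(val_seq)), sorted(class_counts(test_seq)))  # ['snow'] ['snow']

# After a seeded shuffle, each split draws from every class
shuffled = files[:]
random.Random(42).shuffle(shuffled)
for split in split_80_10_10(shuffled):
    print(len(split), dict(class_counts(split)))
```

On the real lists, class_counts(val_list) should report all 11 weather classes; a class missing from a split suggests the list was not shuffled.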

The Dataset class

We create our WeatherDataset class by subclassing PyTorch’s Dataset class:

import os
import cv2
from torch.utils.data import Dataset

class WeatherDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_filepath = self.image_paths[idx]
        image = cv2.imread(image_filepath)              # OpenCV loads images as BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # convert to RGB
        # The class name is the parent folder; os.path also handles Windows separators
        label = os.path.basename(os.path.dirname(image_filepath))
        label = class_to_idx[label]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

__getitem__ expects an index. The DataLoader supplies these indices automatically, calling __getitem__ once for every sample in a batch. In __getitem__, we load the image at index idx, extract the label from its parent folder name in the file path, and then run the image through the transform we passed in. The method returns the transformed image (a Tensor, once the transform includes ToTensor) and its integer label.
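A real DataLoader also handles shuffling, worker processes, and collating samples into tensors. The sketch below strips that away to show only the call pattern described above: one __getitem__ call per index, with the optional transform applied inside it. ToyDataset and iterate_batches are hypothetical names, and this is a simplified imitation, not PyTorch’s implementation.

```python
class ToyDataset:
    """Hypothetical map-style dataset with the same protocol as WeatherDataset."""

    def __init__(self, values, labels, transform=None):
        self.values = values
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.values)

    def __getitem__(self, idx):
        value = self.values[idx]
        if self.transform is not None:
            value = self.transform(value)
        return value, self.labels[idx]


def iterate_batches(dataset, batch_size):
    """Simplified, non-shuffling stand-in for DataLoader: index, fetch, collate."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        # One __getitem__ call per sample in the batch
        samples = [dataset[i] for i in range(start, stop)]
        # Collate a list of (value, label) pairs into (values, labels)
        values, labels = zip(*samples)
        yield list(values), list(labels)


ds = ToyDataset([10, 20, 30, 40, 50], [0, 1, 0, 1, 2], transform=lambda v: v / 10)
for values, labels in iterate_batches(ds, batch_size=2):
    print(values, labels)
# [1.0, 2.0] [0, 1]
# [3.0, 4.0] [0, 1]
# [5.0] [2]
```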


Because we create separate training, validation, and test Datasets, we can pass each one its own transformations. Here the validation and test sets share val_transforms.

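from torchvision import transforms

# Only the Normalize step survives in the original; the other steps are assumptions, not the author's exact pipeline
train_transforms = transforms.Compose([
    transforms.ToPILImage(), transforms.Resize((224, 224)),  # cv2 array -> PIL; batching needs one size (224 assumed)
    transforms.RandomHorizontalFlip(),                       # assumed augmentation
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
val_transforms = transforms.Compose([
    transforms.ToPILImage(), transforms.Resize((224, 224)), transforms.ToTensor(),  # no random augmentation
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

train_dataset = WeatherDataset(train_list, train_transforms)
val_dataset = WeatherDataset(val_list, val_transforms)
test_dataset = WeatherDataset(test_list, val_transforms)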

You can pass different transformations to each dataset, as they will be applied in the corresponding __getitem__ method.


The final step is to wrap each dataset in a DataLoader, which loads the data in batches for the model. Working in mini-batches lets us process data in chunks that fit in our GPU’s memory.

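from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)   # reshuffle every epoch
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)      # evaluation order does not matter
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)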
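As a quick sanity check on the loaders, the number of batches per epoch follows from the split sizes printed earlier. With the default drop_last=False, a final partial batch is kept, so the count rounds up; assuming the 5490/686/686 split, these numbers should match len(train_loader), len(val_loader), and len(test_loader).

```python
import math

batch_size = 4
split_sizes = {"train": 5490, "val": 686, "test": 686}  # sizes printed by the split above

# drop_last=False keeps the last, smaller batch, so round up
batches_per_epoch = {name: math.ceil(n / batch_size) for name, n in split_sizes.items()}
print(batches_per_epoch)  # {'train': 1373, 'val': 172, 'test': 172}
```

Each loader’s last batch holds 2 images, since 5490 and 686 both leave a remainder of 2 when divided by 4.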

Let’s visualize a batch of training images after augmentation.

import numpy as np
import matplotlib.pyplot as plt
import torchvision

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))  # CHW tensor -> HWC array
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean                   # undo Normalize
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
dataiter = iter(train_loader)
inputs, classes = next(dataiter)

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[int(x)] for x in classes])
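The un-normalization lines in imshow invert transforms.Normalize, which maps each channel value x to (x - mean) / std; multiplying by std and adding mean recovers x. A quick NumPy round trip on a hypothetical random 2x2 RGB image confirms this, assuming the same ImageNet mean and std used above.

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# Hypothetical 2x2 RGB image in [0, 1], channels last (HWC), as imshow sees it after transpose
rng = np.random.default_rng(0)
pixels = rng.uniform(0, 1, size=(2, 2, 3))

normalized = (pixels - mean) / std   # what Normalize does per channel
restored = std * normalized + mean   # what imshow does before plotting

print(np.allclose(restored, pixels))  # True
```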
PyTorch Split Dataset Using Random Split

In the real world, large datasets are hard to come by, so it might seem like a waste to not use all of the data at our disposal during the training process. Consequently, it may be tempting to reuse training data for testing or cut corners while compiling test data. Be forewarned: if the test set isn’t well constructed, we won’t be able to draw any meaningful conclusions about our model.
