A limitation of calculating loss over the training dataset is that examples from every class are treated the same, which for imbalanced datasets means the model adapts far more to one class than another. Class weights allow the model to pay more attention to examples from the minority class than the majority class in datasets with a severely skewed class distribution.

This tutorial demonstrates how to create a loss function for an imbalanced dataset that weights the minority class in proportion to its underrepresentation. You will use PyTorch to define the loss function and class weights to help the model learn from the imbalanced data.

First, generate a random dataset, then summarize the class distribution to confirm that the dataset was created as expected.

import torch
from collections import Counter

x = torch.randn(20, 5)  # raw, unnormalized scores (logits) for each class
y = torch.randint(0, 5, (20,))

print(Counter(y.numpy()))  # e.g. Counter({1: 5, 4: 5, 2: 4, 0: 4, 3: 2})

We can see that the dataset has an imbalanced class distribution.

Class Weight

To calculate the proper weights for each class, you can use the sklearn utility function shown in the example below.


from sklearn.utils.class_weight import compute_class_weight
import numpy as np

class_weights = compute_class_weight(class_weight='balanced', classes=np.unique(y.numpy()), y=y.numpy())
class_weights = torch.tensor(class_weights, dtype=torch.float)

print(class_weights)  # e.g. tensor([1.0000, 1.0000, 4.0000, 1.0000, 0.5714])

The class weight penalizes mistakes on samples of class[i] with weight class_weight[i] instead of 1, so a higher class weight puts more emphasis on that class.
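A small sketch of this scaling (with made-up logits), using reduction='none' so the per-sample losses are visible:

```python
import torch
import torch.nn.functional as F

weights = torch.tensor([1.0, 4.0])  # class 1 weighted 4x class 0

logits = torch.tensor([[1.0, 2.0], [1.0, 2.0]])
targets = torch.tensor([0, 1])

# Per-sample losses with and without class weights.
unweighted = F.cross_entropy(logits, targets, reduction='none')
weighted = F.cross_entropy(logits, targets, weight=weights, reduction='none')

# Each sample's loss is scaled by the weight of its true class.
print(weighted / unweighted)  # tensor([1., 4.])
```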

Loss Function

The CrossEntropyLoss() function used to train the PyTorch model takes an argument called "weight". This argument lets you assign a float value to each class, indicating its importance.

criterion_weighted = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
loss_weighted = criterion_weighted(x, y)

weight should be a 1D Tensor assigning a weight to each of the classes.

reduction='mean': the loss is normalized by the sum of the corresponding class weights for each element. This is the default.

reduction='none': the loss is returned per element, and you have to take care of the normalization yourself.
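You can verify this normalization by reproducing the weighted mean manually from the per-element losses (the logits below are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

weights = torch.tensor([1.0, 3.0])  # class 1 weighted 3x class 0

logits = torch.tensor([[2.0, 0.5], [0.2, 1.5]])
targets = torch.tensor([0, 1])

# Per-sample losses without any weighting.
per_sample = F.cross_entropy(logits, targets, reduction='none')

# reduction='mean' scales each sample's loss by the weight of its true
# class, then divides by the sum of those weights.
manual = (per_sample * weights[targets]).sum() / weights[targets].sum()

weighted = nn.CrossEntropyLoss(weight=weights, reduction='mean')(logits, targets)
print(torch.allclose(manual, weighted))  # True
```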

Usually, you increase the weight for minority classes, so that their loss also increases, forcing the model to learn from these samples.
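Putting the pieces together, here is a minimal training-loop sketch; the linear model, learning rate, and epoch count are illustrative assumptions, and the weights are hard-coded for self-containment (in practice, compute them from the class counts as shown earlier):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(20, 5)
y = torch.randint(0, 5, (20,))

# Illustrative class weights; derive them from the data in practice.
class_weights = torch.tensor([1.0, 1.0, 4.0, 1.0, 0.5714])

model = nn.Linear(5, 5)  # a minimal stand-in for a real model
criterion_weighted = nn.CrossEntropyLoss(weight=class_weights, reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion_weighted(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```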
