In the real world, our data often has imbalanced classes, e.g., 99.9% of observations belong to class 1 and only 0.1% to class 2. In the presence of imbalanced classes, accuracy suffers from a paradox: a model can be highly accurate yet lack any predictive power.

For example, imagine we are trying to predict the presence of a very rare cancer that occurs in 0.1% of the population. After training our model, we find its accuracy is 95%. However, 99.9% of people do not have cancer, so a model that simply predicted that nobody had that form of cancer would be 4.9 percentage points more accurate, yet it clearly cannot predict anything. For this reason, we are often motivated to use other metrics such as the confusion matrix, precision, recall, and the F1 score.
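To make the paradox concrete, here is a minimal sketch on synthetic labels (made up for illustration, not real patient data): a "model" that always predicts the majority class reaches roughly 99.9% accuracy while detecting nothing.

import numpy as np
from sklearn.metrics import accuracy_score

rng = np.random.default_rng(0)
y_rare = (rng.random(100_000) < 0.001).astype(int)  # ~0.1% positive class
y_naive = np.zeros_like(y_rare)                      # always predict "no cancer"

print(f"Naive accuracy: {accuracy_score(y_rare, y_naive):.4f}")  # ~0.999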

When we have balanced classes, accuracy is, just as in binary classification, a simple and interpretable choice of evaluation metric. Accuracy is the number of correct predictions divided by the number of observations, and it works just as well for multiclass classification as for binary classification. However, when we have imbalanced classes, we should be inclined to use other evaluation metrics.
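As a quick illustration with toy multiclass labels (assumed for this sketch), accuracy is simply the fraction of matching predictions, which is exactly what scikit-learn's accuracy_score computes:

import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([0, 2, 1, 1, 0, 2])
y_pred = np.array([0, 1, 1, 1, 0, 2])

manual = (y_true == y_pred).mean()  # correct predictions / total observations
print(manual, accuracy_score(y_true, y_pred))  # both print 0.8333...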

Confusion matrices are an easy, effective visualization of a classifier’s performance. One of the major benefits of confusion matrices is their interpretability. Each column of the matrix (often visualized as a heatmap) represents predicted classes, while every row shows actual classes. The end result is that every cell is one possible combination of predicted and actual classes. 
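To make that layout concrete, here is a toy three-class example (made-up labels): rows correspond to actual classes and columns to predicted classes, which is also the orientation scikit-learn's confusion_matrix uses.

from sklearn.metrics import confusion_matrix

y_true = ["cat", "dog", "dog", "cat", "bird"]
y_pred = ["cat", "dog", "cat", "cat", "bird"]

# Rows: actual class, columns: predicted class (order fixed by the labels argument)
print(confusion_matrix(y_true, y_pred, labels=["bird", "cat", "dog"]))
# [[1 0 0]
#  [0 2 0]
#  [0 1 1]]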

Evaluate the model on the test data

We have trained the CIFAR-10 model for 5 epochs on the training dataset. Now we need to check whether the network has learned anything at all. We will do this by predicting the class labels that the neural network outputs and checking them against the ground truth.

import torch
from tqdm import tqdm

y_true = []
y_pred = []

model.eval()  # switch to evaluation mode
with torch.no_grad():  # gradients are not needed for inference
    for data in tqdm(testloader):
        images, labels = data[0].to(device), data[1]
        y_true.extend(labels.numpy())  # collect ground-truth labels

        outputs = model(images)

        # The class with the highest output score is the prediction
        _, predicted = torch.max(outputs, 1)
        y_pred.extend(predicted.cpu().numpy())
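Before building the confusion matrix, a quick sanity check (not part of the original snippet) is to compute the overall test accuracy from the labels we just collected:

from sklearn.metrics import accuracy_score

print(f"Test accuracy: {accuracy_score(y_true, y_pred):.3f}")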

Although the confusion matrix can be computed manually by comparing the actual and predicted class labels, scikit-learn provides a convenient confusion_matrix function that we can use, as follows:

from sklearn.metrics import confusion_matrix

cf_matrix = confusion_matrix(y_true, y_pred)

The array that was returned after executing the code provides us with information about the different types of errors the classifier made on the test dataset. 
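For example, the diagonal of this array holds the number of correct predictions for each class, so a rough per-class accuracy can be read off directly (a small sketch, reusing cf_matrix from above):

import numpy as np

per_class_acc = cf_matrix.diagonal() / cf_matrix.sum(axis=1)  # correct / total actual per class
print(np.round(per_class_acc, 3))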

Now we can attach the class names to the rows and columns and create a labeled confusion matrix that concisely summarizes the results of testing the classifier:

import pandas as pd

class_names = ('plane', 'car', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck')

# Create a pandas DataFrame with class names as row and column labels
dataframe = pd.DataFrame(cf_matrix, index=class_names, columns=class_names)
(Figure: PyTorch confusion matrix)

We can visualize this DataFrame as a heatmap using Seaborn and Matplotlib. The following confusion matrix plot, with the added class labels, should make the results a little bit easier to interpret:

import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(8, 6))

# Create a heatmap with integer counts annotated in each cell
sns.heatmap(dataframe, annot=True, cbar=False, cmap="YlGnBu", fmt="d")

plt.title("Confusion Matrix")
plt.ylabel("True Class")
plt.xlabel("Predicted Class")
plt.tight_layout()
plt.show()
(Figure: PyTorch confusion matrix heatmap)

All correct predictions lie along the diagonal of the table (the darker blue cells in the heatmap), so it is easy to visually inspect the table for prediction errors: any value outside the diagonal represents a misclassification.

This is probably best explained using an example. In the confusion matrix above, the top-left cell is the number of images that are actually planes and were correctly predicted as planes. The model does not do as well, however, at telling dogs from cats.

There are three things worth noting about confusion matrices. First, a perfect model will have values along the diagonal and zeros everywhere else. A bad model, by contrast, will have its observation counts spread roughly evenly across the cells.

Second, a confusion matrix lets us see not only where the model was wrong, but also how it was wrong. That is, we can look at patterns of misclassification. For example, our model had an easy time telling trucks from dogs, but a much harder time distinguishing dogs from cats.
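One way to surface such patterns programmatically (a small sketch, reusing cf_matrix and class_names from above) is to zero out the diagonal and look for the largest remaining cell:

import numpy as np

errors = cf_matrix.copy()
np.fill_diagonal(errors, 0)  # ignore correct predictions
true_idx, pred_idx = np.unravel_index(errors.argmax(), errors.shape)
print(f"Most common confusion: {class_names[true_idx]} predicted as "
      f"{class_names[pred_idx]} ({errors[true_idx, pred_idx]} times)")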

Finally, confusion matrices work with any number of classes (although if we had one million classes in our target vector, the confusion matrix visualization might be difficult to read).
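Since precision, recall, and the F1 score were mentioned earlier, it is worth noting that scikit-learn can report all of them per class from the same predictions (a quick sketch, assuming y_true, y_pred, and class_names from above):

from sklearn.metrics import classification_report

print(classification_report(y_true, y_pred, target_names=class_names))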
