Model accuracy is not a reliable metric of performance, because it yields misleading results when the validation data set is unbalanced. For example, if there were 90 cats and only 10 dogs in the validation data set and the model predicted every image as a cat, the overall accuracy would still be 90%.
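
As a minimal numeric sketch of the cats/dogs example above (the labels and predictions here are made up purely for illustration):

import numpy as np

# Validation set: 90 cats (label 0) and 10 dogs (label 1)
labels = np.array([0] * 90 + [1] * 10)
# A degenerate model that predicts "cat" for every image
predictions = np.zeros(100, dtype=int)

print(np.mean(labels == predictions))  # 0.9 -- 90% accuracy, yet no dog is ever detected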

The confusion matrix allows us to visualize the performance of the trained model. It makes it easy to see whether the system is confusing two classes, and it summarizes the results of testing the model for further inspection. In this tutorial, we create a simple Convolutional Neural Network (CNN) to classify MNIST digits and visualize its confusion matrix in TensorBoard.

Download Dataset

We’re going to construct a simple neural network to classify images in the MNIST dataset. This dataset consists of 28×28 grayscale images of the digits 0-9, i.e., 10 classes.

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

train_images, test_images = train_images / 255.0, test_images / 255.0

classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
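
As a quick sanity check (purely illustrative), you can confirm the shapes and value range after preprocessing:

print(train_images.shape)  # (60000, 28, 28, 1)
print(test_images.shape)   # (10000, 28, 28, 1)
print(train_images.min(), train_images.max())  # 0.0 1.0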

Define Simple CNN Model

First, create a very simple model and compile it, setting up the optimizer and loss function, then train it.

model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))

model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x=train_images,
          y=train_labels,
          epochs=5,
          validation_data=(test_images, test_labels))

The compile step also specifies that you want to log the accuracy of the classifier along the way.
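
For example, if you assign the fit() call above to a variable, the logged metrics are available per epoch from the returned History object (a minimal sketch):

history = model.fit(x=train_images,
                    y=train_labels,
                    epochs=5,
                    validation_data=(test_images, test_labels))

print(history.history['accuracy'])      # training accuracy per epoch
print(history.history['val_accuracy'])  # validation accuracy per epoch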

Create a Confusion Matrix

You can use TensorFlow’s tf.math.confusion_matrix() to compute the confusion matrix.

# Take the argmax of the predicted probabilities to get class labels
# (Sequential.predict_classes() was removed in newer Keras versions)
y_pred = np.argmax(model.predict(test_images), axis=1)
con_mat = tf.math.confusion_matrix(labels=test_labels, predictions=y_pred).numpy()

Normalize the confusion matrix so that each row sums to 1. This makes it easier to interpret which classes are being misclassified.

con_mat_norm = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)

con_mat_df = pd.DataFrame(con_mat_norm,
                     index = classes, 
                     columns = classes)

Keras Confusion Matrix

The diagonal elements represent the fraction of points for which the predicted label equals the true label, while off-diagonal elements are those that were mislabeled by the model.
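
Because each row of the normalized matrix sums to 1, its diagonal can be read as the per-class recall, i.e. the fraction of each true digit that was classified correctly; a quick sketch:

# Per-class recall read off the diagonal of the normalized confusion matrix
per_class_recall = np.diag(con_mat_norm)
for digit, recall in zip(classes, per_class_recall):
    print(f'digit {digit}: recall {recall:.2f}')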

import matplotlib.pyplot as plt
import seaborn as sns

figure = plt.figure(figsize=(8, 8))
sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()

We use Matplotlib to plot the confusion matrix and the Seaborn library to render it as a heatmap.

Plot Keras Confusion matrix

The confusion matrix shows that this model has some problems. “9”, “5”, and “2” are getting confused with each other. The model needs more work.
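
One way to quantify this is to look at the largest off-diagonal entries of the unnormalized confusion matrix; a small sketch:

# Zero out the diagonal so only misclassifications remain
errors = con_mat.copy()
np.fill_diagonal(errors, 0)

# Report the three largest off-diagonal counts (true label, predicted label)
top = np.argsort(errors, axis=None)[::-1][:3]
for idx in top:
    true_label, pred_label = np.unravel_index(idx, errors.shape)
    print(f'{errors[true_label, pred_label]} images of "{true_label}" predicted as "{pred_label}"')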

Plot Confusion Matrix in TensorBoard

Using the TensorFlow Image Summary API, you can log the confusion matrix as an image and view it in TensorBoard. Here’s what you’ll do:

  1. Create the Keras TensorBoard callback to log basic metrics
  2. Create a Keras LambdaCallback to log the confusion matrix at the end of every epoch
  3. Train the model using Model.fit(), making sure to pass both callbacks

You need some boilerplate code to convert the plot to a tensor: tf.summary.image() expects a rank-4 tensor of shape (batch_size, height, width, channels), so the image needs to be reshaped accordingly.

import io

logdir = 'logs/images'
file_writer = tf.summary.create_file_writer(logdir + '/cm')

def log_confusion_matrix(epoch, logs):
  # Use the model to predict the values from the validation dataset.
  test_pred = np.argmax(model.predict(test_images), axis=1)

  con_mat = tf.math.confusion_matrix(labels=test_labels, predictions=test_pred).numpy()
  con_mat_norm = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)

  con_mat_df = pd.DataFrame(con_mat_norm,
                     index = classes, 
                     columns = classes)

  figure = plt.figure(figsize=(8, 8))
  sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
  plt.tight_layout()
  plt.ylabel('True label')
  plt.xlabel('Predicted label')
  
  buf = io.BytesIO()
  plt.savefig(buf, format='png')

  plt.close(figure)
  buf.seek(0)
  image = tf.image.decode_png(buf.getvalue(), channels=4)

  image = tf.expand_dims(image, 0)
  
  # Log the confusion matrix as an image summary.
  with file_writer.as_default():
    tf.summary.image("Confusion Matrix", image, step=epoch)

    
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)

cm_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)

You’re now ready to train the model, log the confusion matrix at the end of each epoch, and view it in TensorBoard.

model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0, 
    callbacks=[tensorboard_callback, cm_callback],
    validation_data=(test_images, test_labels))

The “Images” tab displays the image you just logged.
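
To open the dashboard, point TensorBoard at the log directory. If you are working in a notebook such as Colab, you can launch it inline with the TensorBoard notebook extension (assuming the extension is available):

%load_ext tensorboard
%tensorboard --logdir logs/images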

TensorBoard Confusion Matrix

The image is scaled to a default size for easier viewing. If you want to view the unscaled original image, check “Show actual image size” at the upper left.

Run this code in Google Colab