Accuracy alone is not a reliable measure of model performance, because it gives misleading results when the validation data set is imbalanced. For example, if the validation data set contains 90 cats and only 10 dogs and the model predicts every image as a cat, its overall accuracy is 90%, even though it never recognizes a single dog.
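The arithmetic of this example can be checked with a few lines of NumPy. This is a minimal sketch; encoding cats as 0 and dogs as 1 is an assumption made only for illustration.

```python
import numpy as np

# Hypothetical labels for the example above: 0 = cat, 1 = dog.
y_true = np.array([0] * 90 + [1] * 10)
# A degenerate "model" that predicts cat for every image.
y_pred = np.zeros_like(y_true)

# Overall accuracy: fraction of predictions that match the true label.
accuracy = np.mean(y_pred == y_true)

# Per-class recall: fraction of each class's images that were predicted correctly.
cat_recall = np.mean(y_pred[y_true == 0] == 0)
dog_recall = np.mean(y_pred[y_true == 1] == 1)

print(f"accuracy={accuracy:.2f}, cat recall={cat_recall:.2f}, dog recall={dog_recall:.2f}")
# prints "accuracy=0.90, cat recall=1.00, dog recall=0.00"
```

The per-class numbers expose what the single accuracy figure hides, and they are exactly what the rows of a confusion matrix make visible.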
A confusion matrix lets us visualize the performance of a trained model. It makes it easy to see whether the model is confusing two classes, and it summarizes the test results for further inspection. In this tutorial, we create a simple Convolutional Neural Network (CNN) to classify MNIST digits and visualize its confusion matrix in TensorBoard.
Download Dataset
We’re going to construct a simple neural network to classify images in the MNIST dataset. This dataset consists of 28×28 grayscale images of handwritten digits in 10 classes (0-9), split into 60,000 training images and 10,000 test images.
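import numpy as np
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
# Load MNIST and add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
# Scale pixel values from [0, 255] to [0, 1].
train_images, test_images = train_images / 255.0, test_images / 255.0
classes = list(range(10))  # digit labels 0-9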
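Since the introduction argues that accuracy misleads on imbalanced data, it can be worth checking the label distribution before training. The sketch below is a small NumPy helper under that assumption; `class_counts` and `imbalance_ratio` are hypothetical names, and the synthetic labels stand in for `test_labels` (with MNIST you would call `class_counts(test_labels, 10)`).

```python
import numpy as np

def class_counts(labels, num_classes):
    """Count how many examples belong to each class (0 .. num_classes - 1)."""
    labels = np.asarray(labels).ravel()
    return np.bincount(labels, minlength=num_classes)

def imbalance_ratio(counts):
    """Ratio of the largest to the smallest non-empty class count (1.0 = balanced)."""
    nonzero = counts[counts > 0]
    return nonzero.max() / nonzero.min()

# Synthetic labels used only for this sketch.
labels = np.array([0, 0, 0, 1, 2, 2])
counts = class_counts(labels, num_classes=3)
print(counts)                   # [3 1 2]
print(imbalance_ratio(counts))  # 3.0
```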
Define Simple CNN Model
First, create a simple CNN, compile it with an optimizer and a loss function, and train it.
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x=train_images,
          y=train_labels,
          epochs=5,
          validation_data=(test_images, test_labels))
The metrics=['accuracy'] argument to compile() also tells Keras to report the classifier's accuracy after every epoch.
Create a Confusion Matrix
You can use TensorFlow’s tf.math.confusion_matrix() to build a confusion matrix from the true labels and the model's predicted classes.
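# Sequential.predict_classes() has been removed from recent TensorFlow releases; take the argmax of the softmax output instead.
y_pred = np.argmax(model.predict(test_images), axis=1)
con_mat = tf.math.confusion_matrix(labels=test_labels, predictions=y_pred).numpy()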
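To make it concrete what tf.math.confusion_matrix() counts, here is a minimal NumPy sketch of the same kind of matrix, following TensorFlow's convention that rows are true labels and columns are predictions. The function name `confusion_matrix_np` and the toy labels are hypothetical, not part of any library.

```python
import numpy as np

def confusion_matrix_np(y_true, y_pred, num_classes):
    """Count matrix: cm[i, j] = number of samples with true label i predicted as j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (true, pred) pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm

y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 2, 1]
print(confusion_matrix_np(y_true, y_pred, num_classes=3))
```

Every sample lands in exactly one cell, so the entries sum to the number of samples, and each row sums to the number of samples of that true class.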
Normalize the confusion matrix by row, dividing each count by the number of test images in that true class, so it is easier to see which classes are being misclassified.
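import pandas as pd

# Divide each row by its total, so row i gives the fraction of true class i assigned to each predicted class.
con_mat_norm = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)
con_mat_df = pd.DataFrame(con_mat_norm,
                          index=classes,
                          columns=classes)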
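The division above yields NaN for any class that has no samples in the evaluation set. That cannot happen with the full MNIST test set, but if you reuse the code on other data, a guarded version may help. This is a sketch assuming NumPy only; `normalize_rows` is a hypothetical helper, and the toy matrix is made up.

```python
import numpy as np

def normalize_rows(cm, decimals=2):
    """Divide each row by its total; rows with no samples stay all-zero instead of NaN."""
    cm = np.asarray(cm, dtype=float)
    row_sums = cm.sum(axis=1, keepdims=True)
    # Only divide where the row total is positive; elsewhere keep the zeros from `out`.
    out = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums > 0)
    return np.around(out, decimals=decimals)

cm = np.array([[8, 2, 0],
               [1, 3, 0],
               [0, 0, 0]])  # class 2 has no samples in this toy example
print(normalize_rows(cm))
```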

The diagonal elements show the fraction of each true class that the model labeled correctly (the predicted label equals the true label), while the off-diagonal elements show how often images of one class were mislabeled as another.
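import matplotlib.pyplot as plt
import seaborn as sns

figure = plt.figure(figsize=(8, 8))
sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()  # after the labels, so they are included in the layout
plt.show()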
We use Matplotlib to plot the confusion matrix and the Seaborn library to draw it as a heatmap.

The confusion matrix shows that this model has some problems. “9”, “5”, and “2” are getting confused with each other. The model needs more work.
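Reading the heatmap by eye works for ten classes, but the same information can also be pulled out of the count matrix (con_mat) numerically. The sketch below assumes NumPy only; `per_class_recall` and `most_confused_pairs` are hypothetical helpers, and the 3×3 matrix is invented for illustration, not the result of the model above.

```python
import numpy as np

def per_class_recall(cm):
    """Diagonal divided by row totals (assumes every class appears at least once)."""
    cm = np.asarray(cm, dtype=float)
    return np.diag(cm) / cm.sum(axis=1)

def most_confused_pairs(cm, k=3):
    """Return the k largest off-diagonal cells as (true_label, predicted_label, count)."""
    off = np.asarray(cm).copy()
    np.fill_diagonal(off, 0)                      # ignore correct predictions
    flat = np.argsort(off, axis=None)[::-1][:k]   # flat indices of the largest cells
    rows, cols = np.unravel_index(flat, off.shape)
    return [(int(r), int(c), int(off[r, c])) for r, c in zip(rows, cols)]

cm = np.array([[50, 1, 4],
               [0, 45, 5],
               [7, 2, 40]])
print(per_class_recall(cm).round(3))
print(most_confused_pairs(cm, k=3))  # [(2, 0, 7), (1, 2, 5), (0, 2, 4)]
```

Applied to the MNIST matrix, the top pairs point at the digits that deserve a closer look.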
Plot Confusion Matrix in TensorBoard
Using the TensorFlow Image Summary API, you can log the confusion matrix plot as an image and view it in TensorBoard. Here’s what you’ll do:
- Create the Keras TensorBoard callback to log basic metrics
- Create a Keras LambdaCallback to log the confusion matrix at the end of every epoch
- Train the model using Model.fit(), making sure to pass both callbacks
You need some boilerplate code to convert the plot into a tensor: tf.summary.image() expects a rank-4 tensor of shape (batch_size, height, width, channels), so the decoded PNG image must be given an extra batch dimension.
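The reshaping step can be illustrated without TensorFlow. This NumPy-only sketch mirrors what tf.expand_dims(image, 0) does to a decoded image in the callback; `to_image_batch` is a hypothetical helper, and the zero arrays merely stand in for a decoded RGBA PNG and a grayscale image.

```python
import numpy as np

def to_image_batch(image):
    """Reshape one image into the rank-4 (batch_size, height, width, channels)
    layout that tf.summary.image expects, with batch_size = 1."""
    image = np.asarray(image)
    if image.ndim == 2:  # grayscale (height, width): add a channel axis
        image = image[:, :, np.newaxis]
    if image.ndim != 3:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {image.shape}")
    return np.expand_dims(image, axis=0)  # same idea as tf.expand_dims(image, 0)

rgba = np.zeros((800, 800, 4), dtype=np.uint8)  # stand-in for a decoded RGBA PNG
print(to_image_batch(rgba).shape)                # (1, 800, 800, 4)
print(to_image_batch(np.zeros((28, 28))).shape)  # (1, 28, 28, 1)
```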
import io

logdir = 'logs/images'
# Separate writer for the confusion-matrix images, in a "cm" subdirectory of logdir.
file_writer = tf.summary.create_file_writer(logdir + '/cm')


def log_confusion_matrix(epoch, logs):
    # Use the model to predict the classes of the validation dataset.
    test_pred = np.argmax(model.predict(test_images, verbose=0), axis=1)
    con_mat = tf.math.confusion_matrix(labels=test_labels, predictions=test_pred).numpy()
    con_mat_norm = np.around(con_mat.astype('float') / con_mat.sum(axis=1)[:, np.newaxis], decimals=2)
    con_mat_df = pd.DataFrame(con_mat_norm,
                              index=classes,
                              columns=classes)

    figure = plt.figure(figsize=(8, 8))
    sns.heatmap(con_mat_df, annot=True, cmap=plt.cm.Blues)
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

    # Render the figure to an in-memory PNG, then decode it into an RGBA tensor.
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(figure)
    buf.seek(0)
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add a batch dimension: (height, width, 4) -> (1, height, width, 4).
    image = tf.expand_dims(image, 0)

    # Log the confusion matrix as an image summary.
    with file_writer.as_default():
        tf.summary.image("Confusion Matrix", image, step=epoch)


tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
cm_callback = tf.keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)
You’re now ready to train the model with both callbacks, which logs the confusion matrix image at the end of every epoch so you can view it in TensorBoard.
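# Train with both callbacks. This continues training the model fitted above;
# rebuild and recompile it first if you want to start from scratch.
model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0,
    callbacks=[tensorboard_callback, cm_callback],
    validation_data=(test_images, test_labels))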
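To view the result, start TensorBoard with tensorboard --logdir logs/images and open the “Images” tab, which displays the confusion matrix images you just logged.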

The image is scaled to a default size for easier viewing. If you want to view the unscaled original image, check “Show actual image size” at the upper left.