When working with convolutional networks, there are two ways to see what a model has learned: the filters (weights) and the feature maps (activation maps). In this tutorial, we will visualize the feature maps of a convolutional neural network.

The idea of visualizing a feature map for a specific input image is to understand which features of the input are detected or preserved in the feature maps. The expectation is that the feature maps capture small or fine-grained detail. Let's import all the libraries and modules first.

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import models

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv

Pre-trained VGG Model

We need a CNN model to visualize the feature maps. Instead of training a model from scratch, we can use a pre-trained state-of-the-art image classification model.

PyTorch provides many well-performing image classification models developed by different research groups for ImageNet. One example is the VGG-16 model, which achieved top results in the ILSVRC 2014 competition. This is a good model to use for visualization because it has a simple, uniform structure of serially ordered convolutional and pooling layers.

# Load VGG-16 with weights pre-trained on ImageNet
modelVGG = models.vgg16(pretrained=True)
print(modelVGG)

Running the example will load the model weights into memory and print a summary of the loaded model. 

The network is deep, with 16 learned layers, and it performed very well, meaning that the filters and resulting feature maps will capture useful features.
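
The printed summary should look roughly like the following (abridged here; the exact formatting depends on your torchvision version):

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    ...
  )
)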

In order to explore the feature maps, we need an input for the VGG16 model that can be used to create activations. We will use a simple image of a bee.

We need to load the bee image at the size expected by the model, in this case, 224×224. Next, the image needs to be converted to a tensor and expanded from a 3D tensor to a 4D tensor with the dimensions [samples, channels, rows, cols], where we only have one sample.

# Read the image and convert it from OpenCV's BGR order to RGB
img = cv.imread("/content/hymenoptera_data/val/bees/1297972485_33266a18d9.jpg")
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
plt.imshow(img)
plt.show()


# Standard ImageNet preprocessing: crop to 224x224, convert to a tensor,
# and normalize with the ImageNet channel means and standard deviations.
# Note that RandomResizedCrop takes a random crop, so the feature maps
# will vary slightly between runs.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

img = transform(img)    # cv.imread already returns a NumPy array
img = img.unsqueeze(0)  # add the batch dimension
print(img.size())
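
Running this should print torch.Size([1, 3, 224, 224]): one sample, three color channels, and 224×224 pixels.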

Accessing Convolutional Layers

We need to save all the convolutional layers from the VGG net. These layers are nested inside Sequential containers, so we traverse the model's children to retrieve them. The following code shows how to retrieve all the convolutional layers.

no_of_layers = 0
conv_layers = []

model_children = list(modelVGG.children())

# Walk the top-level children; the convolutional layers are nested
# inside Sequential containers (e.g. modelVGG.features).
for child in model_children:
    if isinstance(child, nn.Conv2d):
        no_of_layers += 1
        conv_layers.append(child)
    elif isinstance(child, nn.Sequential):
        for layer in child.children():
            if isinstance(layer, nn.Conv2d):
                no_of_layers += 1
                conv_layers.append(layer)
print(no_of_layers)

First, we initialize a no_of_layers variable to keep track of the number of convolutional layers. Next, we go through all the layers of the VGG16 model, appending each convolutional layer to conv_layers. For VGG-16, this finds 13 convolutional layers (the remaining 3 learned layers are fully connected).
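
Because all of VGG-16's convolutional layers live in the features submodule, an equivalent shortcut is to filter that container directly. This is just an optional sketch; the rest of the tutorial uses the conv_layers list built above.

# Optional shortcut: nn.Sequential is iterable, so we can filter it directly.
conv_layers_alt = [m for m in modelVGG.features if isinstance(m, nn.Conv2d)]
print(len(conv_layers_alt))  # 13 convolutional layers in VGG-16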

# Feed the image through the first conv layer, then chain each layer's
# output into the next. Note that this applies only the conv layers,
# skipping the ReLU and pooling layers between them.
results = [conv_layers[0](img)]
for i in range(1, len(conv_layers)):
    results.append(conv_layers[i](results[-1]))
outputs = results

We give the image as input to the first convolutional layer. After that, we use a for loop to pass each layer's output on to the next layer until we reach the last convolutional layer.
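
As a quick sanity check (optional), we can print the shape of each layer's output. Since only the convolutional layers are chained here, with no pooling in between, the spatial size stays at 224×224 while the number of channels grows from 64 up to 512.

# Optional: inspect the shape of each convolutional layer's output.
for i, feature_map in enumerate(outputs):
    print(f"Conv layer {i + 1}: {tuple(feature_map.shape)}")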

Visualizing the Feature Maps

We know that the number of feature maps (i.e., the depth or number of channels) in deeper layers is much greater than one, such as 64, 256, or 512. To keep the plots readable, we plot only the first 16 two-dimensional feature maps of each layer as a 2×8 grid of images.

# For each convolutional layer, plot the first 16 feature maps of the
# single sample as a 2x8 grid of grayscale images.
for num_layer in range(len(outputs)):
    plt.figure(figsize=(50, 10))
    layer_viz = outputs[num_layer][0, :, :, :]  # drop the batch dimension
    layer_viz = layer_viz.data                  # detach from the autograd graph
    print("Layer ", num_layer + 1)
    for i, fmap in enumerate(layer_viz):
        if i == 16:
            break
        plt.subplot(2, 8, i + 1)
        plt.imshow(fmap, cmap='gray')
        plt.axis("off")
    plt.show()
    plt.close()

[Figure: feature maps produced by each convolutional layer]

The feature maps are the result of applying filters to an input image. Feature maps output by earlier layers can provide insight into the internal representation that the model has of a specific input at a given point in the network.
