PyTorch builds a computation “graph” of the operations your model performs, and TensorBoard can visualize these model graphs so you can see what they look like. TensorBoard is TensorFlow’s visualization toolkit. It lets you do a wide range of things, from visualizing your model structure to watching training progress.

In this post, we’ll see how to use TensorBoard to visualize a PyTorch model, and we’ll explore the structure of the model we create.

To install TensorBoard for PyTorch, use the following command:

pip install tensorboard

Once TensorBoard is installed, PyTorch’s torch.utils.tensorboard module lets you log models and metrics to a directory for visualization in the TensorBoard UI. Scalars, images, histograms, graphs, and embedding visualizations are all supported for PyTorch models.
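As a minimal sketch of the logging side, the snippet below writes a few scalar and histogram summaries with SummaryWriter. The tag names ("demo/loss", "demo/weights"), the made-up loss values, and the temporary log directory are all hypothetical; pointing tensorboard --logdir at the printed directory should show them on the Scalars and Histograms tabs.

```python
import os
import tempfile

import torch
from torch.utils.tensorboard import SummaryWriter

# Hypothetical log directory; any writable path works.
log_dir = tempfile.mkdtemp(prefix="tb_demo_")
writer = SummaryWriter(log_dir)

for step in range(3):
    # Scalar: one number per step; here a made-up, decreasing "loss".
    writer.add_scalar("demo/loss", 1.0 / (step + 1), step)
    # Histogram: the distribution of a tensor's values, e.g. a layer's weights.
    writer.add_histogram("demo/weights", torch.randn(1000), step)

writer.close()  # flush pending events to disk and close the event file

# SummaryWriter stores everything in "events.out.tfevents.*" files inside log_dir.
event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(log_dir, event_files)
```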

Create Model

We’ll define a simple convolutional model for 3 x 32 x 32 input images, taken from a previous tutorial.

import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),  # output: 32 x 32 x 32
    nn.ReLU(),

    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),  # output: 64 x 16 x 16
    nn.BatchNorm2d(64),

    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),

    nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),  # output: 128 x 8 x 8
    nn.BatchNorm2d(128),

    nn.Flatten(),  # 128 x 8 x 8 -> 8192 features
    nn.Linear(8 * 8 * 128, 1024),
    nn.ReLU(),

    nn.Linear(1024, 10),  # output: 10 values, e.g. one score per class
)

Printing a PyTorch model gives you some idea of the layers involved and their specifications.
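A minimal sketch of what print() does and does not show, using a small hypothetical stand-in model (not the one above) so the snippet runs on its own. The layer list comes from print(model); the total parameter count has to be computed separately.

```python
import torch.nn as nn

# Hypothetical stand-in model, kept tiny so the snippet is self-contained.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 32 * 32, 10),
)

# print() lists each layer and its constructor arguments, in order...
print(model)

# ...but not output shapes or parameter counts; the count can be summed by hand.
num_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {num_params}")
```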

[Figure: PyTorch print model summary]

Printing the model does not give as much detail as Keras’ model.summary(), and PyTorch has no built-in model.summary() method.

The torchsummary package (pip install torchsummary) fills this gap. It complements print(your_model) with a per-layer table of output shapes and parameter counts, similar to the model.summary() API in TensorFlow/Keras.

from torchsummary import summary
# ... model defined as above ...
summary(model, input_size=(3, 32, 32), device="cpu")  # the model was never moved to a GPU
[Figure: PyTorch model summary using torchsummary]

summary() takes the input size without the batch dimension. Its batch_size argument defaults to -1, which is why the batch dimension appears as -1 in the output: the table applies to any batch size.
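torchsummary works roughly by registering a forward hook on each layer and passing a dummy input through the model. The sketch below reproduces that idea with plain PyTorch’s register_forward_hook. layer_summary is a hypothetical helper, not part of torchsummary, and its output format differs from torchsummary’s table. It repeats the model from the Create Model section so it runs on its own, and it confirms the shape comments in that definition.

```python
import torch
import torch.nn as nn

# The model from the "Create Model" section, repeated to keep this self-contained.
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.MaxPool2d(2, 2), nn.BatchNorm2d(64),
    nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1), nn.ReLU(),
    nn.MaxPool2d(2, 2), nn.BatchNorm2d(128),
    nn.Flatten(),
    nn.Linear(8 * 8 * 128, 1024), nn.ReLU(),
    nn.Linear(1024, 10),
)


def layer_summary(model, input_size, batch_size=2):
    """Hypothetical helper: return (name, output_shape, num_params) per leaf layer."""
    rows, hooks = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # Show the batch dimension as -1, like torchsummary's table.
            shape = [-1] + list(output.shape[1:])
            params = sum(p.numel() for p in module.parameters(recurse=False))
            rows.append((name, shape, params))
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf layers only
            label = f"{type(module).__name__}-{name}"
            hooks.append(module.register_forward_hook(make_hook(label)))

    was_training = model.training
    model.eval()  # so BatchNorm does not update its running stats on the dummy batch
    try:
        with torch.no_grad():
            model(torch.rand(batch_size, *input_size))
    finally:
        for h in hooks:
            h.remove()
        model.train(was_training)
    return rows


rows = layer_summary(model, (3, 32, 32))
for name, shape, params in rows:
    print(f"{name:<15} {str(shape):<22} {params:>10,}")
print(f"Total params: {sum(p for _, _, p in rows):,}")
```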

Write Model Summary

TensorBoard is a web interface that reads logged data from files and displays it. To make logging easy, PyTorch provides a utility class called SummaryWriter in torch.utils.tensorboard. The SummaryWriter class is your main entry point for logging data for visualization by TensorBoard.

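import torch
from torch.utils.tensorboard import SummaryWriter
# ... model defined as above ...
writer = SummaryWriter('/content/logsdir')
writer.add_graph(model, torch.rand(1, 3, 32, 32))  # trace the model with one dummy 3 x 32 x 32 image
writer.close()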

The SummaryWriter class provides a high-level API to create an event file in a given directory and add summaries and events to it. The class updates the file contents asynchronously. This allows a training program to call methods to add data to the file directly from the training loop, without slowing down training.
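A minimal sketch of logging from inside a training loop, as described above. The regression problem, the nn.Linear model, the "train/loss" tag, and the temporary log directory are hypothetical stand-ins chosen so the loop runs in a few milliseconds. Each add_scalar call only hands the value to the writer, and close() makes sure everything has reached the event file.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

torch.manual_seed(0)

# Hypothetical toy problem: random inputs and a fixed linear target.
x = torch.randn(64, 4)
y = x @ torch.tensor([[1.0], [-2.0], [0.5], [3.0]])

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

log_dir = tempfile.mkdtemp(prefix="tb_train_")  # hypothetical log directory
writer = SummaryWriter(log_dir)

losses = []
for step in range(20):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    # Log the loss for this step; it appears under "train/loss" on the Scalars tab.
    writer.add_scalar("train/loss", loss.item(), step)
    losses.append(loss.item())

writer.close()  # flush any queued events and close the event file

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, logged to {log_dir}")
```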

Running TensorBoard

To launch TensorBoard, run the tensorboard command from a terminal, passing the --logdir argument to tell TensorBoard where the logged data is.

tensorboard --logdir=/content/logsdir

This launches a local server that serves the TensorBoard UI from the data SummaryWriter wrote to disk. Navigating to http://localhost:6006 (TensorBoard’s default port) should show the following.

[Figure: PyTorch TensorBoard]

Here, we will be able to see our network graph. One of TensorBoard’s strengths is its ability to visualize complex model structures.

Let’s visualize the model we built. Double-click the “Sequential” node to expand it and see a detailed view of the individual operations that make up the model.

[Figure: PyTorch TensorBoard graph]

Double-clicking the node again collapses it. Note that the graph is drawn bottom-up: data flows from bottom to top, so it is upside-down compared to the code. Otherwise, the graph closely matches the PyTorch model definition, with extra edges to other computation nodes.

Each of these blocks can be expanded by clicking its plus sign to see more detail. For example, expanding a “Conv2d” block shows that it is made up of several subcomponents. Scroll to zoom in and out, and click and drag to pan.

[Figure: PyTorch model block]

You can also see metadata by clicking on a node. This allows you to see inputs, outputs, shapes, and other details.


Run this code in Google Colab