Because the dataset available for a given task is often too small, training an entire convolutional network from scratch is rarely practical. Instead, it is common to pre-train a ConvNet on a very large dataset and then use it either as an initialization or as a fixed feature extractor for the task of interest.

In this tutorial, you will learn how to remove the last layer from ResNet. ResNet makes it possible to build deeper neural networks by using skip connections that jump over some layers. There are different versions of ResNet, including ResNet-18, ResNet-34, ResNet-50, and so on. In this tutorial, we will use the pretrained ResNet-50 model.

Replace Last Fully-Connected Layer

The code snippet below loads the ResNet-50 model pre-trained on the ImageNet dataset.

import torch
import torch.nn as nn

from torchvision import models

model = models.resnet50(weights='IMAGENET1K_V1')
print(model)
[Figure: ResNet PyTorch output layer]

Because we are going to use this network for image classification on a different dataset, it will have to predict a different set of output classes. The printed model above shows that the final fully connected layer, fc, is a Linear() layer with 1000 output nodes, so the network needs to be updated.

This must be changed to match your dataset. To do so, we replace the final fully connected layer using the lines of code below, here with 2 output classes:

num_ftrs = model.fc.in_features

model.fc = nn.Linear(num_ftrs, 2)

# run a forward pass on a dummy input; the output now has one logit per class
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape)  # torch.Size([1, 2])

print(model)
[Figure: PyTorch ResNet transfer learning]

Using nn.Identity()

Alternatively, you could load the original model and replace the last layer with nn.Identity(), which might be the simpler approach.


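# load the pre-trained weights so that the extracted features are meaningful
model = models.resnet50(weights='IMAGENET1K_V1')
# replace the last linear layer with nn.Identity; its constructor arguments are ignored
model.fc = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)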

print(model)

nn.Identity() is a placeholder identity operator that is argument-insensitive. The nn.Identity module just returns its input without any manipulation and can be used, for example, to replace other layers.

A common use case for nn.Identity() is to get the “features” of a pre-trained model instead of the class logits.
Here is an example:

# get features for input
x = torch.randn(1, 3, 224, 224)
out = model(x)
print(out.shape) #torch.Size([1, 2048])

As mentioned before, nn.Identity() just returns its input without cloning or otherwise manipulating it. The input and output are therefore the same tensor; this is a “pass-through” layer. If you manipulate the input in place, the output of nn.Identity() changes as well.
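The pass-through behavior can be checked on a small tensor. This is a minimal sketch; the variable names are illustrative.

```python
import torch
import torch.nn as nn

identity = nn.Identity()
inp = torch.zeros(3)
passed = identity(inp)

# The module returns the very same tensor object, not a copy.
print(passed is inp)  # True

# An in-place change to the input is therefore visible through the output.
inp.add_(1.0)
print(passed)         # tensor([1., 1., 1.])
```

If an independent copy is needed, calling `.clone()` on the output is one option.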
