If you fine-tune a pre-trained model on a different dataset, you usually freeze some of the early layers and update only the later layers. In this tutorial, we look at why you may want to freeze some layers, which ones to freeze, and how to do it in PyTorch. Let’s get started!
A neural network abstracts and transforms information in steps. The features extracted in the initial layers are fairly generic and largely independent of the task, while later layers are much more specific to the particular task. By freezing the initial layers, you keep these general features intact. You then unfreeze only the last few layers so they can be tuned for your task.
I would not recommend unfreezing all layers if your model contains any new, untrained layers. During the first few epochs, these untrained layers produce large gradients that flow back into the pre-trained layers and can overwrite what those layers learned. Your model then ends up training almost as if it had been initialized with random weights.
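One common way to follow this advice is to train in two stages. First, train only the new head with the pre-trained part frozen. Then, once the head has settled, unfreeze everything with a smaller learning rate. The sketch below is a minimal illustration under stated assumptions. `backbone`, `head`, `set_requires_grad`, and `train_steps` are hypothetical names, a tiny fully connected network stands in for a real pre-trained model, and random tensors stand in for a real dataset.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical stand-ins: `backbone` plays the pre-trained part, `head` the new layer.
backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU())
head = nn.Linear(64, 5)
model = nn.Sequential(backbone, head)


def set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Freeze (flag=False) or unfreeze (flag=True) every parameter of `module`."""
    for p in module.parameters():
        p.requires_grad = flag


def train_steps(model: nn.Module, optimizer: optim.Optimizer, steps: int = 5) -> float:
    """Run a few optimizer steps on random data (placeholder for a real DataLoader)."""
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(steps):
        x = torch.randn(8, 32)
        y = torch.randint(0, 5, (8,))
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    return loss.item()


# Stage 1: backbone frozen, only the new head is trained.
set_requires_grad(backbone, False)
backbone_snapshot = backbone[0].weight.detach().clone()
opt_head = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=1e-2)
train_steps(model, opt_head)
backbone_unchanged_in_stage1 = torch.equal(backbone_snapshot, backbone[0].weight)

# Stage 2: the head is no longer random, so unfreeze everything with a smaller learning rate.
set_requires_grad(backbone, True)
opt_all = optim.SGD(model.parameters(), lr=1e-3)
final_loss = train_steps(model, opt_all)
```

The specific learning rates and step counts are arbitrary placeholders. The key point is that the backbone stays untouched while the head is still random.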
Create Pre-trained Model
First, we need a pre-trained model. The models subpackage of the torchvision package provides definitions for many popular image-classification architectures. Calling one of these constructors without arguments builds the model with random weights. To get the pre-trained weights instead, call the constructor with the pretrained=True argument. Let’s load the pre-trained VGG16 model:
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models
model_vgg16 = models.vgg16(pretrained=True)
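The first time you run this, the pre-trained weights are downloaded into your computer’s PyTorch cache folder. In torchvision 0.13 and later, pretrained=True is deprecated in favor of the weights argument, for example models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).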
Next, we freeze the weights of the entire network except the final fully connected layer. That last layer is replaced with a new one with randomly initialized weights, and only this layer is trained. We freeze the pre-trained layers so that training does not destroy the information they contain.
Because we have access to all the modules, layers, and their parameters, we can freeze them by setting each parameter’s requires_grad flag to False. Autograd then skips computing gradients for these parameters in the backward pass, which in turn prevents the optimizer from updating them.
# Freeze every parameter of the pre-trained model
for param in model_vgg16.parameters():
    param.requires_grad = False
If you freeze all the layers except the final fully connected layer, gradients only need to be computed for that final layer, and only its weights are updated. Compared with backpropagating through and updating every layer of the network, this can greatly reduce computation time. The forward pass still runs through the whole network.
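The effect on the backward pass can be checked directly. In the sketch below, a small nn.Sequential stands in for the pre-trained network, which is an assumption made to keep the example fast. Every layer except the last is frozen, one backward pass is run, and the script reports which parameters received a gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in for a pre-trained network: two "early" layers and a final layer.
model = nn.Sequential(
    nn.Linear(16, 32), nn.ReLU(),
    nn.Linear(32, 32), nn.ReLU(),
    nn.Linear(32, 4),  # final fully connected layer, left trainable
)

# Freeze every layer except the final one.
for layer in list(model)[:-1]:
    for p in layer.parameters():
        p.requires_grad = False

x = torch.randn(8, 16)
loss = model(x).pow(2).mean()
loss.backward()

for name, p in model.named_parameters():
    # Frozen parameters never receive a .grad; only the final layer does.
    print(name, "grad is None" if p.grad is None else "has grad")
```

Because neither the input nor the early layers require gradients, autograd does not need to propagate gradients back through them at all.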
This shortcut works because the pre-trained network can already extract generic features from your dataset, so it does not have to learn to extract them from scratch.
Now that some of the parameters are frozen, the optimizer should be given only the parameters with requires_grad=True. By default, the optimizer is written like this:
optimizer_conv = optim.SGD(model_vgg16.parameters(), lr=0.001)
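In older PyTorch releases, this raised an error, because the optimizer refused parameters that do not require gradients. Current releases simply skip parameters whose .grad is None. Even so, passing only the trainable parameters makes the intent explicit and keeps frozen weights out of the optimizer. You can do this by filtering the parameters with a lambda function when constructing the optimizer: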
optimizer_conv = optim.SGD(filter(lambda p: p.requires_grad, model_vgg16.parameters()), lr=0.001)
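To confirm that frozen weights really stay put, the sketch below builds the optimizer with the same filter/lambda pattern and takes one training step. It uses a small stand-in model and random data, both assumptions for illustration. It then compares the weights before and after the step.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Small stand-in model: first layer plays the frozen "pre-trained" part.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Freeze the first layer; the last layer stays trainable.
for p in model[0].parameters():
    p.requires_grad = False

# Give the optimizer only the trainable parameters (same filter/lambda pattern as above).
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

frozen_before = model[0].weight.detach().clone()
head_before = model[2].weight.detach().clone()

# One training step on random data.
optimizer.zero_grad()
loss = nn.functional.cross_entropy(model(torch.randn(8, 16)), torch.randint(0, 4, (8,)))
loss.backward()
optimizer.step()

print("frozen layer changed:", not torch.equal(frozen_before, model[0].weight))  # False
print("final layer changed:", not torch.equal(head_before, model[2].weight))     # True
```

The optimizer holds only two tensors here, the final layer's weight and bias.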
When to Freeze Layers
When most of the weights are frozen, you are optimizing only a small part of the model on top of fixed features. If your dataset is similar to some subset of ImageNet, this should not matter much. If it is quite different from ImageNet, freezing can reduce accuracy. If you have enough computation time, unfreezing everything lets you optimize the whole feature space, which may lead to a better optimum.