Modern convolutional neural networks have millions of parameters. Training them from scratch requires a lot of labeled training data and a lot of computing power. Transfer learning is a technique that shortcuts much of this by taking a piece of a model that has already been trained on a related task and reusing it in a new model.

The intuition behind transfer learning for image classification is that if a model is trained on a large and general enough dataset, it effectively serves as a generic model of the visual world. You can then take advantage of its learned feature maps without having to train a large model on a large dataset from scratch.

Feature Extraction

You can use a pre-trained model to extract meaningful features from new samples. You simply add a new classifier, trained from scratch, on top of the pre-trained model, so that the feature maps learned previously can be repurposed for the new dataset.

You do not need to re-train the entire model. The base convolutional network already contains features that are generically useful for classifying pictures. However, the final classification part of the pre-trained model is specific to the original classification task, and consequently to the set of classes on which the model was trained.

In feature extraction, we start with a pre-trained model and only update the final layer weights from which we derive predictions. It is called feature extraction because we use the pre-trained CNN as a fixed feature-extractor and only change the output layer.
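One way to make "fixed feature extractor" concrete is to run the frozen network once under torch.no_grad() and train only a linear layer on the resulting feature vectors. This is a minimal sketch of that variant, not the approach used in the rest of this tutorial (which keeps the backbone inside the model and freezes its parameters); the helper names and the small random backbone in the example are hypothetical stand-ins for a pre-trained CNN.

```python
import torch
import torch.nn as nn

@torch.no_grad()
def extract_features(backbone, images):
    """Run a frozen backbone in eval mode; no gradients are tracked."""
    backbone.eval()
    return torch.flatten(backbone(images), start_dim=1)

def train_linear_head(features, labels, num_classes, epochs=200, lr=0.1):
    """Fit only a new linear classifier on precomputed features."""
    head = nn.Linear(features.shape[1], num_classes)
    optimizer = torch.optim.SGD(head.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(head(features), labels)
        loss.backward()
        optimizer.step()
    return head

# Toy example: a small random "backbone" stands in for a pre-trained CNN
torch.manual_seed(0)
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.AdaptiveAvgPool2d(1))
images = torch.randn(20, 3, 16, 16)
labels = torch.randint(0, 5, (20,))
features = extract_features(backbone, images)  # shape (20, 8)
head = train_linear_head(features, labels, num_classes=5)
predictions = head(features).argmax(dim=1)
```

With the tutorial's ResNet-18, the backbone would be the pre-trained model with its fc layer set to nn.Identity(), which yields 512-dimensional features per image.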

This tutorial demonstrates how to build a PyTorch model that classifies five species of flowers, using torchvision's ResNet-18 model as an image feature extractor. The ResNet-18 weights were pre-trained on ImageNet, a much larger and more general dataset.


Let’s download our training examples from Kaggle and split them into training and validation sets. The flowers dataset consists of images of flowers with 5 possible class labels.

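import os
import torch
from torchvision import transforms

# Kaggle API credentials from your kaggle.json (keep your real key private)
os.environ['KAGGLE_USERNAME'] = "<your-kaggle-username>"
os.environ['KAGGLE_KEY'] = "<your-kaggle-api-key>"

# Notebook shell commands (Colab/Jupyter): download and extract the archive
!kaggle datasets download -d alxmamaev/flowers-recognition
!unzip -q flowers-recognition.zip

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

transform = transforms.Compose([
        transforms.Resize((224, 224)),  # ResNet was pre-trained on 224x224 images
        transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ImageNet mean/std
])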

Finally, note that ResNet was pre-trained on 224×224 images, so the input images should be resized to (224, 224).


This dataset contains 5 classes, and the images of each class are stored in their own sub-folder, so we can use torchvision's ImageFolder dataset rather than writing our own custom dataset.



from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE)

Visualize Dataset

The flowers dataset consists of labeled images of flowers. Each example contains a JPEG flower image and the class label: what type of flower it is. Let’s display a few images together with their labels.

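import numpy as np
import matplotlib.pyplot as plt

def showimages(imgs, actual_lbls, pred_lbls=None):
  fig = plt.figure(figsize=(21, 12))

  for i, img in enumerate(imgs):
    fig.add_subplot(4, 8, i + 1)
    y = int(actual_lbls[i])
    if pred_lbls is not None:
      y_pre = int(pred_lbls[i])
      title = "prediction: {0}\nlabel: {1}".format(dataset.classes[y_pre], dataset.classes[y])
    else:
      title = "Label: {0}".format(dataset.classes[y])

    # Undo Normalize so the image is displayed with its original colours
    img = img.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img, 0, 1)

    plt.imshow(img)
    plt.title(title)
    plt.axis("off")
  plt.show()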

inputs, classes = next(iter(train_loader))
showimages(inputs, classes)

Figure: a batch of training images from the DataLoader, shown with their labels.

Create Model

We will create the base model from the ResNet model. This is pre-trained on the ImageNet dataset, a large dataset consisting of 1.4M images and 1000 classes. ImageNet is a research training dataset with a wide variety of categories like jackfruit and syringe.

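import torchvision

# torchvision >= 0.13; older releases use resnet18(pretrained=True)
model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)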

Freeze layers

In this step, you will freeze the convolutional base created from the previous step and use it as a feature extractor. Additionally, you add a classifier on top of it and train the top-level classifier.

# Freeze every pre-trained parameter so the optimizer cannot change it
for param in model.parameters():
    param.requires_grad = False

It is important to freeze the convolutional base before you create the optimizer and train the model. Freezing (setting requires_grad = False) prevents the weights in a given layer from being updated during training: autograd no longer computes gradients for them, so their values stay fixed.
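To see that freezing really keeps weights fixed, the sketch below freezes a small stand-in network, attaches a new trainable output layer, takes one SGD step, and compares parameters before and after. The toy architecture is hypothetical; only the requires_grad mechanics match the tutorial.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Sequential(nn.Linear(10, 16), nn.ReLU())  # stand-in for the conv base
for param in base.parameters():
    param.requires_grad = False                      # freeze the base

head = nn.Linear(16, 3)                              # new layer: trainable by default
toy_model = nn.Sequential(base, head)

base_before = [p.clone() for p in base.parameters()]
head_before = [p.clone() for p in head.parameters()]

# Only the trainable parameters are handed to the optimizer
optimizer = torch.optim.SGD([p for p in toy_model.parameters() if p.requires_grad], lr=0.1)
x, y = torch.randn(8, 10), torch.randint(0, 3, (8,))
loss = nn.CrossEntropyLoss()(toy_model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

base_unchanged = all(torch.equal(a, b) for a, b in zip(base_before, base.parameters()))
head_changed = any(not torch.equal(a, b) for a, b in zip(head_before, head.parameters()))
frozen_grads = [p.grad for p in base.parameters()]  # None: no gradients were computed
print(base_unchanged, head_changed)
```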

Reshape the Layer

Here, the weights of the whole network stay frozen except those of the final fully connected layer: that layer is replaced with a new one with randomly initialized weights, and only this layer is trained.

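import torch.nn as nn

num_ftrs = model.fc.in_features  # 512 for ResNet-18

# The new layer's parameters have requires_grad=True by default
model.fc = nn.Linear(num_ftrs, len(dataset.classes))
model = model.to(device)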

The final layer is reshaped so that it has as many outputs as there are classes in the new dataset (five flower classes here).

Create Optimizer

The final step for feature extraction is to create an optimizer that only updates the desired parameters. Only the parameters with requires_grad=True need to be optimized, so we collect them in a list and pass that list to the SGD constructor.

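import torch.nn as nn
import torch.optim as optim

loss_fn = nn.CrossEntropyLoss()

# Collect only the trainable parameters (the new fc layer)
params_to_update = []
for name, param in model.named_parameters():
    if param.requires_grad:
        params_to_update.append(param)

optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)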

Train Model

Now, let’s write a general function to train a model.

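losses = {'train': [], 'val': []}
accuracies = {'train': [], 'val': []}
for epoch in range(1, epochs + 1):
    # each epoch: one pass over the training data (updating the weights), then
    # one pass over the validation data; append each phase's loss and accuracy
    ...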

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(accuracies['train'], label='Training Accuracy')
plt.plot(accuracies['val'], label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(losses['train'], label='Training Loss')
plt.plot(losses['val'], label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

The train() function handles the training and validation of a given model. As input, it takes a PyTorch model, a dictionary of data loaders, a loss function, an optimizer, and the number of epochs to train and validate for.
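A minimal sketch of a train() function matching that description follows. It assumes the data-loader dictionary has 'train' and 'val' keys, records one average loss and accuracy per phase per epoch in the losses and accuracies layout used by the plotting code, and takes the device as an extra parameter; these details are assumptions, not the original implementation.

```python
import torch
import torch.nn as nn

def train(model, dataloaders, loss_fn, optimizer, epochs, device="cpu"):
    """Train on dataloaders['train'] and evaluate on dataloaders['val'] each epoch."""
    model = model.to(device)
    losses = {'train': [], 'val': []}
    accuracies = {'train': [], 'val': []}
    for epoch in range(1, epochs + 1):
        for phase in ('train', 'val'):
            is_train = phase == 'train'
            model.train(is_train)                      # train/eval mode for dropout and batch norm
            running_loss, correct, seen = 0.0, 0, 0
            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                with torch.set_grad_enabled(is_train):  # no autograd graph during validation
                    outputs = model(inputs)
                    loss = loss_fn(outputs, labels)
                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                correct += (outputs.argmax(dim=1) == labels).sum().item()
                seen += inputs.size(0)
            losses[phase].append(running_loss / seen)
            accuracies[phase].append(correct / seen)
        print(f"epoch {epoch}: train loss {losses['train'][-1]:.4f}, "
              f"val acc {accuracies['val'][-1]:.4f}")
    return losses, accuracies
```

In the notebook it could be called as losses, accuracies = train(model, {'train': train_loader, 'val': val_loader}, loss_fn, optimizer_ft, epochs, device). Note that model.train() also lets the frozen BatchNorm layers update their running statistics even though their weights are fixed; keeping the backbone in eval mode during training is a common variant.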

Figure: training and validation accuracy and loss per epoch.

Predict Images

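def predict_images(model, images, actual_label):
  model.eval()  # use batch-norm running statistics for inference
  with torch.no_grad():
    inputs = images.to(device)
    outputs = model(inputs)
    _, preds = torch.max(outputs, 1)  # index of the highest logit = predicted class
  showimages(images, actual_label, preds.cpu())

images, classes = next(iter(val_loader))
predict_images(model, images, classes)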
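In predict_images, torch.max(outputs, 1) returns the largest logit and its index for each row; the index is the predicted class. If class probabilities are wanted as well, applying softmax over the class dimension gives them without changing the prediction, since softmax preserves the ordering of the logits. A small sketch with made-up logits (not real model output):

```python
import torch

# Made-up logits for 2 images and 5 classes
logits = torch.tensor([[2.0, 0.5, -1.0, 0.0, 1.0],
                       [0.1, 0.2, 3.0, 0.3, -0.5]])
_, preds = torch.max(logits, 1)       # same call as in predict_images
probs = torch.softmax(logits, dim=1)  # each row sums to 1
confidence, prob_preds = probs.max(dim=1)
print(preds.tolist(), prob_preds.tolist())  # [0, 2] [0, 2]
```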

Figure: validation images shown with their predicted and actual labels.

When working with a small dataset, it is a common practice to take advantage of features learned by a model trained on a larger dataset in the same domain. This is done by instantiating the pre-trained model and adding a fully-connected classifier on top. The pre-trained model is “frozen” and only the weights of the classifier get updated during training. In this case, the convolutional base extracted all the features associated with each image and you just trained a classifier that determines the image class given that set of extracted features.
