Suppose you have a trained model named modelA and you want to copy its weights and biases into another model named modelB. This is a typical situation when you want to initialize the weights of a deep-learning network with the weights of a pre-trained model.

For the sake of example, we will create a neural network and random data tensors for training.

import torch
import torch.nn as nn

input=torch.randn(8,10)
target=torch.randn(8,4)

class MyModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer1=nn.Linear(10,4) 
    self.act1=nn.Sigmoid()
  
  def forward(self,x):
    x=self.layer1(x)
    x=self.act1(x)
    return x

We instantiate the model, define the loss function (binary cross-entropy), and create an optimizer, in this case Adam with a learning rate of 0.001, registering all the parameters of the model with the optimizer.

modelA=MyModel()

criterion=nn.BCELoss()
optimizer=torch.optim.Adam(modelA.parameters(),lr=0.001)

Let’s take a look at a single training step. We run the input data through the model, layer by layer, to make a prediction. This is the forward pass.

result=modelA(input)

loss=criterion(result,target)
optimizer.zero_grad()
loss.backward()
optimizer.step()

We use the model’s prediction and the corresponding label to calculate the error (loss). The next step is to backpropagate this error through the network. Backward propagation is kicked off when we call .backward() on the error tensor. Finally, we call .step() to initiate gradient descent: the optimizer adjusts each parameter by the gradient stored in its .grad attribute.
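
To see what backpropagation leaves behind, we can inspect the gradients. This is just an illustrative check: after loss.backward(), each parameter’s gradient lives in its .grad attribute, which is exactly what optimizer.step() reads.

# Inspect the gradients computed by the backward pass
for name, p in modelA.named_parameters():
    print(name, p.grad.shape)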

Assignment operations

The most straightforward approach to the copy is through normal assignment operations.

modelB=modelA
id(modelA)==id(modelB) #True

A PyTorch model is mutable: if we change either of the two models, the change has a direct impact on the other one too, as both names point to the same object in memory.

Therefore, normal assignment is typically only used when we have to deal with immutable object types.
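
A quick way to see this sharing in action (a minimal sketch, using an in-place update purely for illustration): a change made through modelB is immediately visible through modelA, because both names refer to the same module.

# Modify the weights through modelB and observe the change through modelA
before=modelA.layer1.weight.clone()
with torch.no_grad():
    modelB.layer1.weight.add_(1.0)                 # in-place change via modelB
print(torch.equal(modelA.layer1.weight, before))   # False: modelA changed too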

When it comes to nn.Module, there is no clone method available, so you can either use copy.deepcopy or create a new instance of the model and copy the parameters over using state_dict() and load_state_dict().

If you want to use the same state_dict in two independent models, you could use deepcopy or initialize a second model and load the state_dict again.

Deep Copy

Python comes with a module called copy that offers several copy operations. A deep copy takes a copy of the original object and then recursively copies its inner objects, so a change in either model won’t affect the other.

import copy

modelB=copy.deepcopy(modelA)

In PyTorch, the learnable parameters (i.e. weights and biases) of a torch.nn.Module model are contained in the model’s parameters (accessed with model.parameters()). 

# Check params
for p1, p2 in zip(modelA.parameters(), modelB.parameters()):
    print(torch.equal(p1, p2))

You can see that the parameters of the original and the copied model are equal.

deepcopy recursively copies every member of an object, so it copies everything. For each parameter it creates a new tensor instance with its own memory allocation for the tensor data. The autograd history is not copied, as you cannot call copy.deepcopy on a non-leaf tensor.

Therefore, a deep copy is more suitable when we are dealing with a PyTorch model and want to ensure that a change in one model won’t affect the other.
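
As a sanity check (illustrative only), we can repeat the in-place modification from before: this time a change to modelB leaves modelA untouched, because each model owns its own parameter tensors.

# The copied parameters live in separate memory
print(modelA.layer1.weight.data_ptr()==modelB.layer1.weight.data_ptr())  # False

# An in-place change to modelB no longer affects modelA
with torch.no_grad():
    modelB.layer1.weight.add_(1.0)
print(torch.equal(modelA.layer1.weight, modelB.layer1.weight))  # False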

state_dict and load_state_dict

A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. They can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.
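
To get a feel for what a state_dict contains, we can print its entries (illustration only; the exact keys depend on the layer names in MyModel):

# Each key is a parameter (or buffer) name, each value is a tensor
for name, tensor in modelA.state_dict().items():
    print(name, tensor.size())
# layer1.weight torch.Size([4, 10])
# layer1.bias torch.Size([4])

To copy the parameters into an independent model, we grab the state_dict from modelA and load it into a freshly created modelB.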

sd=modelA.state_dict()

modelB=MyModel()

modelB.load_state_dict(sd)

Please note that load_state_dict only copies parameters and buffers.
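
Since the state_dict is a plain dictionary of tensors, it can also be written to disk and restored later. A minimal sketch (the filename "modelA.pth" is just an example):

# Save the parameters of modelA to disk
torch.save(modelA.state_dict(), "modelA.pth")

# Restore them into a fresh instance
modelB=MyModel()
modelB.load_state_dict(torch.load("modelA.pth"))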

Optimizer

If you copy the model with deepcopy or state_dict, the existing optimizer still references the parameters of the original model, so it will not update the copy. You need to reinitialize the optimizer with the parameters of the copied model, and then you can copy the optimizer’s inner state from one to the other.

Optimizer objects (torch.optim) also have a state_dict, which contains information about the optimizer’s state, as well as the hyperparameters used. 
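
Putting the two points above together, a minimal sketch might look like this (optimizerB is a name introduced here just for illustration):

# Create a new optimizer that tracks modelB's parameters
optimizerB=torch.optim.Adam(modelB.parameters(),lr=0.001)

# Copy the inner state (e.g. Adam's running averages) and hyperparameters
optimizerB.load_state_dict(optimizer.state_dict())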
