To create a PyTorch Module, we'll need a Parameter. nn.Parameter is a subclass of torch.Tensor. The Parameter class doesn't add any functionality other than automatically setting requires_grad=True for us; it's used only as a "marker" to show what to include in the module's parameters. Otherwise, a Parameter behaves just like a tensor, as we wanted.
import torch
import torch.nn as nn

my_param = nn.Parameter(torch.randn(1))
print(my_param)
# Parameter containing:
# tensor([...], requires_grad=True)
One important behavior of torch.nn.Module is registering parameters. Model weights are expressed as instances of torch.nn.Parameter, which have a special behavior: when assigned as attributes of a Module, they are added to the list of that module's parameters. These parameters may be accessed through the parameters() method on the Module class.
class TinyModel(torch.nn.Module):
    def __init__(self):
        super(TinyModel, self).__init__()
        self.linear1 = torch.nn.Linear(100, 200)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(200, 10)
        self.softmax = torch.nn.Softmax(dim=1)  # specify dim explicitly to avoid the implicit-dim warning

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x

tinymodel = TinyModel()

print('\n\nModel params:')
for param in tinymodel.parameters():
    print(param)
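To see that Parameter really is just a marker, note that a plain tensor assigned as an attribute is not picked up by parameters(). A minimal sketch (the NotRegistered class name is ours):

class NotRegistered(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.randn(3)                      # plain tensor: not registered
        self.marked = torch.nn.Parameter(torch.randn(3))  # Parameter: registered

print([name for name, _ in NotRegistered().named_parameters()])
# ['marked'] -- the plain tensor does not appear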
Model parameters are learned and updated using SGD during the training process. However, sometimes there are other quantities that are part of a model's "state" and should be saved as part of its state_dict, even though they are not trained. Registering these tensors as the model's buffers allows PyTorch to track them and save them like regular parameters, but prevents PyTorch from updating them using the SGD mechanism.
If you have tensors in your model that should be saved and restored in the state_dict but not trained by the optimizer, you should register them as buffers. Buffers won't be returned by model.parameters(), so the optimizer won't have a chance to update them.
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1)                    # plain attribute: not tracked
        self.register_buffer('my_buffer', torch.randn(1))  # tracked and saved, but not trained
        self.my_param = nn.Parameter(torch.randn(1))       # tracked, saved, and trained

    def forward(self, x):
        return x

model = MyModel()
print(model.my_tensor)
# tensor([-0.5898])
print(model.state_dict())
# OrderedDict([('my_param', tensor([-1.7432])), ('my_buffer', tensor([1.6982]))])

model.cuda()
print(model.my_tensor)
# tensor([-0.5898])
print(model.state_dict())
# OrderedDict([('my_param', tensor([-1.7432], device='cuda:0')), ('my_buffer', tensor([1.6982], device='cuda:0'))])
One reason to register a tensor as a buffer is to be able to serialize the model and restore all internal states. Another is that all buffers and parameters will be pushed to the device when .cuda() (or .to()) is called on the parent model.

As you can see, model.my_tensor is still on the CPU, where it was created, while all parameters and buffers were pushed to the GPU after calling model.cuda().
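To confirm that the buffer is saved but hidden from the optimizer, we can compare what named_parameters() and named_buffers() return. A minimal sketch reusing MyModel from above:

model = MyModel()
print([name for name, _ in model.named_parameters()])  # ['my_param']
print([name for name, _ in model.named_buffers()])     # ['my_buffer']

# Only my_param ends up in the optimizer; my_buffer and my_tensor are never updated by it.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)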
Use of register_buffer
Batch normalization performs the normalization for each mini-batch and back-propagates the gradients through the normalization parameters. Batch normalization adds two extra learnable parameters per activation, and in doing so preserves the representation ability of the network.
An example of a buffer can be found in the BatchNorm module, where running_mean, running_var, and num_batches_tracked are registered as buffers and updated by accumulating statistics of data forwarded through the layer. This is in contrast to the weight and bias parameters, which learn an affine transformation of the data using regular SGD optimization.
m = nn.BatchNorm2d(5)
input = torch.randn(1, 5, 50, 50)
output = m(input)
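The parameter/buffer split is easy to verify on this module (a quick check; the names come from PyTorch's BatchNorm implementation):

print([name for name, _ in m.named_parameters()])
# ['weight', 'bias']
print([name for name, _ in m.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']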
To access the buffers of a specific layer, you can address them directly as attributes, for example on the BatchNorm2d module above:
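print(m.running_mean)         # per-channel running mean, shape [5]
print(m.running_var)          # per-channel running variance, shape [5]
print(m.num_batches_tracked)  # number of mini-batches seen so far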
You can get all buffers via model.named_buffers(), just as with model.named_parameters(). buffers() returns the same buffers, while named_buffers() additionally returns the corresponding name for each buffer.
print("\nName Buffer")
for name, buf in model.named_buffers():
    print(name, "[", type(name), "]", type(buf), buf.size())
register_buffer vs register_parameter
register_buffer registers a fixed, non-learnable tensor; it does not require a gradient and is never updated by the optimizer.
register_parameter registers a learnable parameter; it requires a gradient and is updated by the optimizer during training.
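A minimal sketch contrasting the two (the Demo class name is ours; note that register_parameter expects an nn.Parameter):

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('running_stat', torch.zeros(3))             # fixed, saved in state_dict
        self.register_parameter('weight', nn.Parameter(torch.randn(3)))  # learnable, saved in state_dict

demo = Demo()
print(demo.weight.requires_grad)                # True
print(demo.running_stat.requires_grad)          # False
print([n for n, _ in demo.named_parameters()])  # ['weight']
print([n for n, _ in demo.named_buffers()])     # ['running_stat']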