There are several optimization strategies that can help convergence when models get complicated. PyTorch abstracts the optimization strategy away from user code, which saves us the boilerplate busywork of updating each and every parameter of our model ourselves.

The torch module has a torch.optim submodule containing classes that implement different optimization algorithms. Every optimizer constructor takes an iterable of parameters (PyTorch tensors, typically obtained from model.parameters()) as its first input. The optimizer keeps references to all the parameters passed to it, so it can update their values in place.
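A minimal sketch of this API, assuming a toy torch.nn.Linear(3, 1) model chosen only for illustration. It checks that the optimizer stores the model's own tensor objects rather than copies, so calling step() changes the model's parameters directly.

```python
import torch

torch.manual_seed(0)

# Toy model, assumed only for illustration
model = torch.nn.Linear(3, 1)

# The first constructor argument is an iterable of parameter tensors
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# The optimizer holds references to the very same tensors (weight, bias), not copies
stored = optimizer.param_groups[0]['params']
same_objects = all(a is b for a, b in zip(stored, model.parameters()))

# One update: compute a loss, backpropagate, then let the optimizer modify the tensors in place
before = model.weight.detach().clone()
loss = model(torch.randn(8, 3)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()
changed = not torch.equal(before, model.weight.detach())

print(same_objects, changed)
```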

In this tutorial, we discuss the SGD (stochastic gradient descent) optimizer with weight decay. With its default settings (no momentum, no weight decay), the SGD optimizer performs plain vanilla gradient descent. The term stochastic comes from the fact that the gradient is typically obtained by averaging over a random subset of all input samples, called a minibatch.
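To make "vanilla gradient descent" concrete, here is a sketch on toy regression data (assumed for illustration). It checks that one torch.optim.SGD step with default settings matches the hand-written update w <- w - lr * grad, where the gradient comes from a loss averaged over a random minibatch of 8 samples.

```python
import torch

torch.manual_seed(0)

# Toy linear-regression data; the minibatch is a random subset of 8 out of 100 samples
X = torch.randn(100, 3)
y = torch.randn(100)
idx = torch.randperm(100)[:8]
xb, yb = X[idx], y[idx]

w = torch.zeros(3, requires_grad=True)
lr = 0.1
optimizer = torch.optim.SGD([w], lr=lr)  # momentum=0 and weight_decay=0 by default

# Loss averaged over the minibatch, so its gradient is the average per-sample gradient
loss = ((xb @ w - yb) ** 2).mean()
optimizer.zero_grad()
loss.backward()

expected = (w - lr * w.grad).detach()  # vanilla gradient-descent update, computed by hand
optimizer.step()

print(torch.allclose(w.detach(), expected))  # True
```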

SGD works on small batches (minibatches) of shuffled data. The noise this introduces into the gradient estimate can help convergence and can keep the optimization process from getting stuck in local minima it encounters along the way.
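A minimal sketch of minibatch SGD on shuffled data, assuming a toy regression problem with 64 samples. DataLoader with shuffle=True draws a new random order every epoch, and the optimizer makes one update per minibatch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Toy regression problem (assumed for illustration): 64 samples, 3 features
X = torch.randn(64, 3)
y = X @ torch.tensor([1.0, -2.0, 0.5]) + 0.1 * torch.randn(64)

# shuffle=True reshuffles every epoch; batch_size sets the minibatch size
loader = DataLoader(TensorDataset(X, y), batch_size=8, shuffle=True)

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = torch.nn.MSELoss()

n_updates = 0
for epoch in range(20):
    for xb, yb in loader:              # each iteration yields one shuffled minibatch
        loss = loss_fn(model(xb).squeeze(1), yb)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()               # one SGD update per minibatch
        n_updates += 1

print(n_updates)  # 20 epochs * (64 / 8) minibatches = 160
```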

L2 regularization

Training a model involves two critical steps: optimization, when we need the loss to decrease on the training set; and generalization, when the model has to work not only on the training set but also on data it has not seen before, like the validation set. The mathematical tools aimed at easing these two steps are sometimes subsumed under the label regularization.

The first way to stabilize generalization is to add a regularization term to the loss. This term is crafted so that the weights of the model tend to be small on their own, limiting how much training makes them grow. In other words, it is a penalty on larger weight values. This makes the loss have a smoother topography, and there’s relatively less to gain from fitting individual samples.

The most popular regularization terms of this kind are L2 regularization, which is the sum of squares of all weights in the model, and L1 regularization, which is the sum of the absolute values of all weights in the model. Both of them are scaled by a (small) factor, which is a hyperparameter we set prior to training.
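As a small worked example, the sketch below computes both penalty terms for a toy torch.nn.Linear(2, 1) whose values are hand-set for illustration: weights 1 and -2, bias 3. Following the training loop in this section, it sums over every tensor in model.parameters(), biases included. The L2 sum is 1 + 4 + 9 = 14 and the L1 sum is 1 + 2 + 3 = 6, so with a factor of 0.01 the terms are about 0.14 and 0.06.

```python
import torch

# Toy model with hand-set parameters, assumed for illustration
model = torch.nn.Linear(2, 1)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[1.0, -2.0]]))
    model.bias.fill_(3.0)

lam = 0.01  # the small scaling factor (hyperparameter)

# L2 term: sum of squares of all parameters, scaled by lam
l2_term = lam * sum(p.pow(2).sum() for p in model.parameters())
# L1 term: sum of absolute values of all parameters, scaled by lam
l1_term = lam * sum(p.abs().sum() for p in model.parameters())

print(l2_term.item(), l1_term.item())  # approximately 0.14 and 0.06
```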

In PyTorch, we can implement regularization quite easily by adding a term to the loss. After computing the loss, whatever the loss function is, we iterate over the parameters of the model, sum their squares (for L2) or absolute values (for L1), add the scaled result to the loss, and backpropagate:

# Assumes model, loss_fn, optimizer, train_loader, device and n_epochs are defined earlier
l2_lambda = 0.001  # regularization strength (hyperparameter)

for epoch in range(1, n_epochs + 1):
    loss_train = 0.0

    for imgs, labels in train_loader:
        imgs = imgs.to(device=device)
        labels = labels.to(device=device)
        outputs = model(imgs)
        loss = loss_fn(outputs, labels)

        # L2 penalty: sum of the squares of every parameter in the model
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
        loss = loss + l2_lambda * l2_norm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_train += loss.item()

Weight Decay

Weight regularization is one way to reduce overfitting of a deep neural network to its training data and to improve the model's performance on new data, such as a holdout test set.

L2 regularization is also referred to as weight decay. The reason for this name becomes clear when we think about SGD and backpropagation: the negative gradient of the L2 regularization term with respect to a parameter w_i is -2 * lambda * w_i, where lambda is the aforementioned hyperparameter.
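This gradient can be checked with autograd. The sketch below uses a toy tensor and lambda = 0.1, both assumed for illustration, and confirms that the gradient of lambda * sum(w_i^2) is 2 * lambda * w_i.

```python
import torch

lam = 0.1                                       # regularization hyperparameter lambda
w = torch.tensor([1.0, -2.0, 0.5], requires_grad=True)

penalty = lam * (w ** 2).sum()                  # L2 term: lambda * sum_i w_i^2
penalty.backward()

# The gradient is 2 * lambda * w, so the negative gradient is -2 * lambda * w_i
print(torch.allclose(w.grad, 2 * lam * w.detach()))  # True
```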

So, adding L2 regularization to the loss function is equivalent to decreasing each weight by an amount proportional to its current value during the optimization step (hence, the name weight decay). 
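A sketch of this "decay" reading, assuming a toy parameter and a data-loss gradient of exactly zero, so that only the weight-decay term acts. Each step() then multiplies the weight by (1 - lr * weight_decay), here 1 - 0.1 * 0.5 = 0.95.

```python
import torch

lr, wd = 0.1, 0.5
w = torch.nn.Parameter(torch.tensor([2.0, -4.0]))
optimizer = torch.optim.SGD([w], lr=lr, weight_decay=wd)

for _ in range(3):
    # Pretend the data loss contributes no gradient, so only weight decay changes w
    w.grad = torch.zeros_like(w)
    optimizer.step()

# After three steps, w has shrunk by a factor of 0.95 ** 3
expected = torch.tensor([2.0, -4.0]) * (1 - lr * wd) ** 3
print(torch.allclose(w.detach(), expected))  # True
```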

import torch.optim as optim

optimizer = optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.5)

Generally, regularization penalizes only the weight parameters w of the model and leaves the bias parameters unpenalized. However, the weight_decay parameter of SGD applies to every parameter passed to the optimizer, including both the weight w and the bias b.

If you want to turn off weight decay for biases, you can use "parameter groups" to apply different optimizer hyperparameters to different sets of network parameters.

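import torch

in_dim, out_dim = 4, 2  # example sizes; out_dim matches the 2-element target below

inputs = torch.randn(in_dim)
target = torch.tensor([1, 2], dtype=torch.float32)

model = torch.nn.Linear(in_dim, out_dim, bias=True)
out = model(inputs)

# Two parameter groups: weight decay on the weights, none on the bias
optm = torch.optim.SGD([
    {'params': [model.weight], 'weight_decay': 0.5},
    {'params': [model.bias], 'weight_decay': 0.0}
], lr=1e-3)

# One training step using the groups above
loss = torch.nn.functional.mse_loss(out, target)
optm.zero_grad()
loss.backward()
optm.step()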

The optimizer uses the same learning rate for the weight and the bias but different weight decay: weight_decay is 0.5 for the weight and 0.0 (no weight decay) for the bias.
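For a model with many layers, listing every tensor by hand gets tedious. A common convention, sketched here on a hypothetical two-layer network, is to split model.named_parameters() by name so that every tensor whose name ends in "bias" goes into the no-decay group. The name-based rule is an assumption of this sketch, not something PyTorch requires.

```python
import torch

# Hypothetical small network, assumed for illustration
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 2),
)

# Names look like "0.weight", "0.bias", "2.weight", "2.bias"
decay, no_decay = [], []
for name, p in model.named_parameters():
    if name.endswith("bias"):
        no_decay.append(p)   # biases: no weight decay
    else:
        decay.append(p)      # weights: apply weight decay

optimizer = torch.optim.SGD([
    {'params': decay, 'weight_decay': 0.5},
    {'params': no_decay, 'weight_decay': 0.0},
], lr=1e-3)

print(len(decay), len(no_decay))  # 2 2
```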

Note that the weight_decay parameter of PyTorch's SGD optimizer corresponds to 2 * lambda, and the optimizer applies the decay directly during the update, as described previously. This is fully equivalent to adding lambda times the sum of the squared weights to the loss, but it avoids accumulating extra terms in the loss and involving autograd.
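This equivalence can be checked numerically. The sketch below uses a toy model and data assumed for illustration. It takes one step with an explicit penalty lambda * sum(p^2) in the loss, and one step with weight_decay = 2 * lambda in the optimizer, starting from identical copies of the model. The resulting parameters agree.

```python
import copy
import torch

torch.manual_seed(0)
lam, lr = 0.01, 0.1
x, y = torch.randn(16, 3), torch.randn(16, 1)
loss_fn = torch.nn.MSELoss()

# Two identical copies of a toy model
model_a = torch.nn.Linear(3, 1)
model_b = copy.deepcopy(model_a)

# (a) L2 penalty added to the loss by hand; no weight_decay in the optimizer
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr)
loss_a = loss_fn(model_a(x), y) + lam * sum(p.pow(2).sum() for p in model_a.parameters())
opt_a.zero_grad()
loss_a.backward()
opt_a.step()

# (b) Plain loss; the optimizer applies weight decay with weight_decay = 2 * lambda
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr, weight_decay=2 * lam)
loss_b = loss_fn(model_b(x), y)
opt_b.zero_grad()
loss_b.backward()
opt_b.step()

same = all(torch.allclose(pa, pb) for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```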

For standard stochastic gradient descent, L2 regularization and weight decay are therefore equivalent, once the weight decay coefficient is rescaled by the learning rate. Because of this equivalence, L2 regularization is very frequently referred to as weight decay, including in popular deep learning libraries.
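The rescaling can be written out for a single SGD step with learning rate \eta on a loss L(w) + \lambda \sum_i w_i^2. The symbol \lambda', introduced here for illustration, denotes the per-step shrink factor of the "pure" weight-decay form:

```latex
\begin{aligned}
\text{L2 term in the loss:}\quad
w_{t+1} &= w_t - \eta\,\bigl(\nabla L(w_t) + 2\lambda\, w_t\bigr) \\
        &= (1 - 2\eta\lambda)\, w_t - \eta\,\nabla L(w_t) \\[4pt]
\text{weight decay:}\quad
w_{t+1} &= (1 - \lambda')\, w_t - \eta\,\nabla L(w_t)
\end{aligned}
```

The two updates coincide when \lambda' = 2\eta\lambda, that is, when the decay factor equals the L2 coefficient (2\lambda, PyTorch's weight_decay) multiplied by the learning rate.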
