Choosing good hyper-parameters, regularization, and architecture for a deep neural network, which are all tightly coupled, typically takes years of experience. Setting the hyper-parameters, including designing the network architecture, requires expertise and extensive trial and error, and is based more on serendipity than on science.

Choosing learning rate, momentum, and weight decay hyper-parameters well will improve the network’s performance. The conventional method is to perform a grid or a random search, which can be computationally expensive and time-consuming. In addition, the effects of these hyper-parameters are tightly coupled with each other, the data, and the architecture.
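Here is a minimal sketch of random search. It assumes a hypothetical `toy_validation_loss` function as a cheap stand-in for "train the model and measure its validation loss", with its minimum placed arbitrarily at a learning rate of 1e-3 and a weight decay of 1e-4. Both hyper-parameters are sampled on a log scale, because useful values span several orders of magnitude.

```python
import math
import random

def toy_validation_loss(lr, wd):
    # Hypothetical stand-in for a full training run; its minimum is
    # placed (arbitrarily) at lr = 1e-3, wd = 1e-4.
    return (math.log10(lr) + 3) ** 2 + (math.log10(wd) + 4) ** 2

def random_search(n_trials=200, seed=0):
    rng = random.Random(seed)
    best = None  # (loss, lr, wd) of the best trial so far
    for _ in range(n_trials):
        lr = 10 ** rng.uniform(-5, -1)   # learning rate sampled on a log scale
        wd = 10 ** rng.uniform(-6, -2)   # weight decay sampled on a log scale
        loss = toy_validation_loss(lr, wd)
        if best is None or loss < best[0]:
            best = (loss, lr, wd)
    return best

loss, lr, wd = random_search()
print(f"best loss={loss:.3f}  lr={lr:.2e}  wd={wd:.2e}")
```

In practice, every call to the objective is a full training run, which is exactly what makes grid and random search so expensive.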

Learning Rate

The learning rate is arguably the most important hyper-parameter of a neural network. It appears in every gradient-based optimization algorithm, such as gradient descent, RMSprop, and Adam.

import tensorflow as tf
optim = tf.keras.optimizers.SGD(learning_rate=1e-03, momentum=0.9)

During the optimization, the algorithm needs to take a series of tiny steps to descend the error mountain in order to minimize the error. Each tiny step has two properties: direction and size.

The direction of the step is determined by the gradient, that is, the derivative of the error with respect to the weights. The gradient points uphill, so the optimizer steps in the opposite direction.

The step size is determined by the learning rate. It determines how fast or slow the optimizer descends the error curve. With a large learning rate, the optimizer takes big steps to descend the error mountain. With a small learning rate, the optimizer takes small steps to descend the error mountain.
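To make direction and step size concrete, here is a minimal plain-Python sketch on a hypothetical one-weight error function E(w) = (w - 3)², whose minimum is at w = 3. The sign of the gradient gives the direction, and the learning rate scales the step.

```python
def gradient_descent(lr, w=0.0, steps=50):
    """Minimize the toy error E(w) = (w - 3)**2 with plain gradient descent."""
    for _ in range(steps):
        grad = 2 * (w - 3)   # dE/dw: points uphill
        w = w - lr * grad    # step downhill; lr sets the step size
    return w

print(gradient_descent(lr=0.1))   # small steps: ends up very close to w = 3
print(gradient_descent(lr=1.1))   # steps too large: overshoots further each time and diverges
```

With lr=0.1, each step shrinks the distance to the minimum by a factor of 0.8. With lr=1.1, the weight jumps past the minimum and lands 1.2 times farther away on the other side, so the error grows instead of shrinking.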

Learning Rate Decay

With a fixed learning rate, you can only change its value between training runs. A common, and often more effective, approach is to decrease the learning rate during training. This is known as a dynamic learning rate. The optimizer takes large steps early on and smaller steps as it approaches a minimum, which helps the model settle at the point where the error is lowest instead of bouncing around it.

There are two main approaches to decreasing the learning rate during training: adaptive learning rates and learning rate decay.

In this post, we focus on learning rate decay for Adam-style optimizers. Older Keras optimizers ship with a standard learning rate decay controlled by the `decay` parameter. It is off by default (decay=0), so you have to activate it explicitly by setting a suitable value. The newer Keras optimizers (TensorFlow 2.11 and later) no longer support `decay`. Instead, you pass a learning rate schedule as the `learning_rate` argument.
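Here is a minimal sketch of learning rate decay with a Keras schedule. It assumes TensorFlow 2.11 or later and uses `tf.keras.optimizers.schedules.InverseTimeDecay`, which computes `initial_learning_rate / (1 + decay_rate * step / decay_steps)`. With `decay_steps=1`, this follows the same formula as the legacy `decay` argument, `lr / (1 + decay * iterations)`. The values 1e-3 and 1e-4 are illustrative, not tuned.

```python
import tensorflow as tf

# lr(step) = 1e-3 / (1 + 1e-4 * step / 1)
lr_schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=1e-3,
    decay_steps=1,
    decay_rate=1e-4,
)

# The schedule goes where a constant learning rate would normally go;
# the optimizer evaluates it at every training step.
adam_optim = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

for step in (0, 1_000, 10_000):
    print(step, float(lr_schedule(step)))
```

After 10,000 steps, the learning rate has halved to 5e-4. Other built-in schedules, such as `ExponentialDecay`, plug in the same way.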

Adam Optimizer

Adam optimization is a stochastic gradient descent method based on adaptive estimation of first-order and second-order moments. According to Kingma and Ba (2014), the method is "computationally efficient, has little memory requirement, invariant to diagonal rescaling of gradients, and is well suited for problems that are large in terms of data/parameters".
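A minimal NumPy sketch of a single Adam update, following the update rule in Kingma and Ba's paper, shows where the two moments come in. The epsilon default matches Keras's 1e-7, but Keras's own implementation may differ in small details, such as exactly where epsilon is applied.

```python
import numpy as np

def adam_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-7):
    """One Adam update for parameters w given gradient g at step t (t starts at 1)."""
    m = beta1 * m + (1 - beta1) * g            # first moment: running mean of gradients
    v = beta2 * v + (1 - beta2) * g ** 2       # second moment: running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)               # bias correction (moments start at zero)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps)  # per-parameter, rescaled step
    return w, m, v

w = np.array([1.0, -2.0])
g = np.array([0.5, -4.0])
w, m, v = adam_step(w, g, np.zeros_like(w), np.zeros_like(w), t=1)
print(w)  # both parameters move by about lr = 0.001, despite an 8x gradient ratio
```

Dividing by the square root of the second moment is what makes Adam invariant to rescaling the gradient. It is also, as discussed later, what interferes with L2 regularization.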

For a while, Adam was considered robust and well suited to a wide range of non-convex optimization problems in machine learning.

Eventually, fewer research articles used it to train their models. Newer studies began to explicitly discourage its use, showing in several experiments that plain SGD with momentum performed better. Adam got a new lease on life when Ilya Loshchilov and Frank Hutter pointed out in their paper that weight decay seemed to be implemented incorrectly for Adam in every library. They proposed a simple fix, which they call AdamW.

AdamW Optimizer

As Loshchilov and Hutter note, state-of-the-art results on popular image classification datasets such as CIFAR-10 and CIFAR-100 were still obtained with SGD with momentum. Moreover, adaptive gradient methods had been reported to generalize worse than SGD with momentum on a diverse set of deep learning tasks, such as image classification, character-level language modeling, and constituency parsing.

The main contribution of their paper is to improve regularization in Adam by decoupling the weight decay from the gradient-based update. Their analysis shows that Adam generalizes substantially better with decoupled weight decay than with L2 regularization.
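Side by side, the two update rules look like this. Here α is the learning rate, λ is the regularization factor (with the L2 penalty written as (λ/2)‖θ‖²), and m̂_t, v̂_t are Adam's bias-corrected moment estimates computed from g_t. The decay term is written scaled by α, as tf.keras.optimizers.AdamW does; the paper writes the same idea with a separate schedule multiplier η_t.

```latex
\begin{aligned}
&\text{Adam + L2:} && g_t = \nabla f_t(\theta_{t-1}) + \lambda\,\theta_{t-1},
  && \theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} \\
&\text{AdamW:} && g_t = \nabla f_t(\theta_{t-1}),
  && \theta_t = \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} - \alpha\lambda\,\theta_{t-1}
\end{aligned}
```

With L2, the penalty gradient λθ is mixed into g_t and is therefore divided by √v̂_t along with everything else. With AdamW, the decay term bypasses that rescaling entirely.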

In this tutorial, we compare SGD with momentum and L2 regularization against AdamW for training a deep neural network, and we show how to implement both in Keras.

Dataset

This tutorial shows how to classify images of flowers using a tf.keras.Sequential model and load data using tensorflow_datasets.

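import tensorflow as tf
import tensorflow_datasets as tfds

IMAGE_SIZE = 180  # assumed input resolution; the original does not state it
BATCH_SIZE = 32   # assumed batch size

def preprocess(image, label):
    # Photos come in different sizes: resize them and scale pixels to [0, 1]
    image = tf.image.resize(image, (IMAGE_SIZE, IMAGE_SIZE)) / 255.0
    return image, label

train_ds = tfds.load("tf_flowers", split="train[:80%]", as_supervised=True)
valid_ds = tfds.load("tf_flowers", split="train[80%:]", as_supervised=True)

# model.fit expects batches of equally sized images
train_ds = train_ds.shuffle(256).map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
valid_ds = valid_ds.map(preprocess).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)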

This tutorial uses a dataset of about 3,700 photos of flowers. The dataset contains five sub-directories, one per class:

flower_photos/
  daisy/
  dandelion/
  roses/
  sunflowers/
  tulips/

Keras model

The Keras Sequential model consists of two convolution blocks with a max pooling layer in each of them. There’s a fully-connected layer with 128 units on top of it that is activated by a ReLU. This model has not been tuned for high accuracy; the goal of this tutorial is to show a standard approach.

import tensorflow as tf

def get_model(reg=None):
    # `reg` is applied to every kernel; pass None for no L2 penalty (e.g. with AdamW)
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3)),
        tf.keras.layers.Conv2D(64, 3, kernel_regularizer=reg, padding="same"),
        tf.keras.layers.MaxPooling2D(2),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dropout(0.2),

        tf.keras.layers.Conv2D(64, 3, kernel_regularizer=reg, padding="same"),
        tf.keras.layers.MaxPooling2D(2),
        tf.keras.layers.ReLU(),
        tf.keras.layers.Dropout(0.2),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu", kernel_regularizer=reg),
        # 5 output logits, one per flower class (the loss uses from_logits=True)
        tf.keras.layers.Dense(5, kernel_regularizer=reg),
    ])
    return model

Train model with SGD and L2 Regularization

SGD with momentum accumulates a velocity from past gradients. This accelerates updates along directions where the gradient is consistent and dampens oscillations, which leads to faster convergence.
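Here is a minimal plain-Python sketch of the momentum update, using the velocity rule documented for Keras SGD (`velocity = momentum * velocity - learning_rate * g`, then `w = w + velocity`, without Nesterov). A constant gradient of 1.0 is a hypothetical input chosen to make the effect easy to see.

```python
def sgd_momentum(grad_fn, w, lr=0.1, momentum=0.9, steps=200):
    """SGD with (non-Nesterov) momentum; returns final weight and velocity."""
    velocity = 0.0
    for _ in range(steps):
        velocity = momentum * velocity - lr * grad_fn(w)  # accumulate past gradients
        w = w + velocity                                  # move by the velocity
    return w, velocity

_, v_momentum = sgd_momentum(lambda w: 1.0, 0.0)
_, v_plain = sgd_momentum(lambda w: 1.0, 0.0, momentum=0.0)
print(v_momentum, v_plain)
```

With a constant gradient, the velocity approaches lr / (1 - momentum) = 1.0, ten times the plain SGD step of 0.1. That amplification is the "acceleration" along a consistent direction.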

l2_reg = tf.keras.regularizers.l2(l2=0.001)
model = get_model(l2_reg)

sgd_optim = tf.keras.optimizers.SGD(learning_rate=1e-03, momentum=0.9)
loss_fn   = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics   = [tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]

model.compile(optimizer=sgd_optim, loss=loss_fn, metrics=metrics)

l2_sgd_hist = model.fit(train_ds, validation_data=valid_ds, epochs=7)

Train model with AdamW 

To use the AdamW optimizer, use the class tf.keras.optimizers.AdamW (available in TensorFlow 2.11 and later):

model = get_model()  # no L2 penalty: AdamW applies decoupled weight decay instead

adamw_optim = tf.keras.optimizers.AdamW(weight_decay=0.001, learning_rate=3e-04)
loss_fn   = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics   = [tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]

model.compile(optimizer=adamw_optim, loss=loss_fn, metrics=metrics)

adamw_hist = model.fit(train_ds, validation_data=valid_ds, epochs=7)

tf.keras.optimizers.AdamW implements the AdamW optimizer described in "Decoupled Weight Decay Regularization" by Loshchilov & Hutter.

Visualize training results

Create plots of the loss and accuracy on the training and validation sets:

import matplotlib.pyplot as plt

plt.style.use("ggplot")
plt.figure(figsize=(10, 6))

plt.title("AdamW vs SGD with L2 Losses")
plt.plot(l2_sgd_hist.history["loss"], label="Training loss SGD with L2")
plt.plot(adamw_hist.history["loss"],  label="Training loss AdamW")

plt.plot(l2_sgd_hist.history["val_loss"], label="Validation loss SGD with L2", linestyle="dashed")
plt.plot(adamw_hist.history["val_loss"],  label="Validation loss AdamW", linestyle="dashed")

plt.xlabel("# epochs")
plt.ylabel("loss")
plt.legend()
plt.show()

[Figure: AdamW vs SGD with L2 Losses]

plt.style.use("ggplot")
plt.figure(figsize=(10, 6))

plt.title("AdamW vs SGD with L2 Accuracy")
plt.plot(l2_sgd_hist.history["accuracy"], label="Training accuracy SGD with L2")
plt.plot(adamw_hist.history["accuracy"],  label="Training accuracy AdamW")

plt.plot(l2_sgd_hist.history["val_accuracy"], label="Validation accuracy SGD with L2", linestyle="dashed")
plt.plot(adamw_hist.history["val_accuracy"],  label="Validation accuracy AdamW", linestyle="dashed")

plt.xlabel("# epochs")
plt.ylabel("accuracy")
plt.legend()
plt.show()

[Figure: AdamW vs SGD with L2 Accuracy]

L2 regularization and weight decay are not identical. For SGD, the two techniques can be made equivalent by reparameterizing the weight decay factor based on the learning rate, but this is not the case for Adam.
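A minimal NumPy check of the SGD equivalence, using random weights and gradients as hypothetical inputs. One SGD step on `loss + (l2_coef / 2) * ||w||^2` is compared with one step of decoupled weight decay. The two match exactly when `l2_coef = wd / lr`.

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=5)   # current weights (hypothetical)
g = rng.normal(size=5)   # gradient of the data loss (hypothetical)
lr = 0.1                 # learning rate
wd = 0.01                # decoupled weight decay factor

# SGD with L2: the penalty gradient l2_coef * w is added to the data gradient
l2_coef = wd / lr
w_l2 = w - lr * (g + l2_coef * w)

# SGD with decoupled weight decay: shrink the weights, then take the gradient step
w_wd = (1 - wd) * w - lr * g

print(np.allclose(w_l2, w_wd))  # True: same update, different parameterization
```

Note that Keras's `l2(l)` regularizer adds `l * sum(w**2)` to the loss, so its gradient is `2 * l * w`; in the notation above, that corresponds to `l2_coef = 2 * l`. In Adam, this reparameterization breaks down because the penalty gradient is also divided by the per-parameter second-moment estimate.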

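In particular, when L2 regularization is combined with adaptive gradients, weights with large historic parameter and/or gradient amplitudes are regularized less than they would be with weight decay. This happens because the penalty gradient is divided by the same running gradient magnitude as the rest of the update. As a result, L2 regularization is much less effective in Adam than in SGD.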
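The following NumPy simulation is a toy illustration, not an experiment from the paper. Two weights start at 1.0 and receive a hypothetical zero-mean "data" gradient that flips sign every step, with scale 0.1 for the first weight and 10 for the second. With L2 inside Adam, the penalty gradient λw is divided by √v̂, so the weight with the large gradient history barely decays. With decoupled decay, applied as `lr * lam * w` in the style of Keras's AdamW, both weights decay identically.

```python
import numpy as np

def simulate(decoupled, steps=1000, lr=0.01, lam=0.1,
             beta1=0.9, beta2=0.999, eps=1e-8):
    grad_scale = np.array([0.1, 10.0])  # weight 0: small gradients, weight 1: large
    w = np.ones(2)
    m = np.zeros(2)
    v = np.zeros(2)
    for t in range(1, steps + 1):
        # Toy zero-mean data gradient that flips sign every step
        g = grad_scale if t % 2 == 1 else -grad_scale
        if decoupled:
            w = w - lr * lam * w   # AdamW: shrink the weights directly
        else:
            g = g + lam * w        # L2: penalty gradient enters Adam's moments
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g ** 2
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w

w_l2 = simulate(decoupled=False)
w_adamw = simulate(decoupled=True)
print("Adam + L2:", w_l2)     # small-gradient weight decays to ~0, large-gradient weight barely moves
print("AdamW:    ", w_adamw)  # both weights shrink by the same amount
```

Under Adam with L2, the effective decay rate is roughly `lr * lam / grad_scale`, so it is a hundred times weaker for the second weight. Under AdamW, it is `lr * lam` for every weight, regardless of its gradient history.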

One possible explanation of why Adam and other adaptive gradient methods might be outperformed by SGD with momentum is that common deep learning libraries (at the time of the paper) only implemented L2 regularization, not the original weight decay. Therefore, on tasks where L2 regularization benefits SGD, as it does on many popular image classification datasets, Adam leads to worse results than SGD with momentum, for which L2 regularization behaves as expected.