During training, the neural network iteratively adjusts its weights to minimize the error on the training examples. After enough training, we expect the network to perform reasonably well on the task it has been trained for.
In addition to defining the model, instantiating it, and running data through it, we must be able to train and test it. Here we implement a practical example in PyTorch: the training loop for a classifier on the MNIST digits dataset.
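```python
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Assumed setup (not shown in the original): MNIST loader, a small model,
# a loss function, and an optimizer.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True,
                                      transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=4, shuffle=True)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(),
                      nn.Dropout(p=0.2), nn.Linear(128, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    model.train()  # put dropout/batch norm layers into training mode
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:  # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')
```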
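The loop above only trains. A minimal sketch of a matching test pass might look like the following; the function name `evaluate` and the choice of accuracy as the metric are our own assumptions, not part of the original example. It switches the model to inference mode with `model.eval()` and disables gradient tracking with `torch.no_grad()`, since no parameters are updated at test time.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """Hypothetical helper: classification accuracy of `model` over `loader`."""
    model.eval()  # switch dropout/batch norm layers to inference behavior
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed at test time
        for inputs, labels in loader:
            outputs = model(inputs)
            predicted = outputs.argmax(dim=1)  # index of the highest logit
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total
```

With a loader over the MNIST test split (for example, one built from `torchvision.datasets.MNIST(..., train=False, ...)`, assumed here), you would call `evaluate(model, testloader)` after training. Note that `evaluate` leaves the model in eval mode, which is one reason the training loop calls `model.train()` at the top of every epoch.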
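To train the model, we need a loss function that measures how far the model's outputs are from the targets. Once we have computed the loss for a mini-batch, we can use what we learned in the previous section and call backward() on it. This computes the gradient of the loss with respect to each parameter p and accumulates it in p's grad attribute; optimizer.step() then uses these gradients to update the parameters. Because backward() adds to whatever is already stored in grad rather than overwriting it, the loop calls optimizer.zero_grad() before every backward pass.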
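A small sketch of what backward() does to the grad attributes, using a toy nn.Linear layer and random data (both assumptions chosen for illustration): gradients appear only after the first backward pass, and a second backward pass adds to them instead of replacing them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 1)       # toy model: 3 inputs, 1 output
x = torch.randn(4, 3)         # random mini-batch of 4 examples
target = torch.randn(4, 1)
criterion = nn.MSELoss()

# Before any backward pass, .grad is None for every parameter (weight and bias).
print([p.grad for p in layer.parameters()])  # [None, None]

loss = criterion(layer(x), target)
loss.backward()               # stores d(loss)/d(p) in p.grad for each parameter p
first = layer.weight.grad.clone()

# Recompute the same loss and call backward() again without zeroing:
# the new gradient is ADDED to the one already stored.
loss = criterion(layer(x), target)
loss.backward()
print(torch.allclose(layer.weight.grad, 2 * first))  # True
```

This accumulation is why the training loop clears gradients with optimizer.zero_grad() on every iteration. In PyTorch 2.x, zero_grad() sets gradients to None by default rather than filling them with zeros; either way, the next backward pass starts from a clean slate.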
Note also that we call model.train() at the beginning of each training epoch. The model.train() and model.eval() calls tell PyTorch whether the model is in training mode or in inference (evaluation) mode. You might wonder why these calls are needed at all, if a neural network behaves the same way at training and test time. The answer is that it does not always: some layers are designed to behave differently in the two modes.
For example, dropout is normally active during training, while during evaluation it is bypassed or, equivalently, its drop probability is set to zero. This is controlled through the training attribute of the Dropout module. PyTorch lets us switch between the two modes by calling model.train() or model.eval() on any nn.Module subclass. The call is propagated recursively to all submodules, so if a Dropout layer is among them, it behaves accordingly in subsequent forward and backward passes.
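To see both behaviors in isolation, here is a small sketch using a toy nn.Sequential (an arbitrary model of our own, not the MNIST classifier). It checks that model.train() and model.eval() reach the nested Dropout layer through its training attribute, and that PyTorch's dropout zeroes elements and scales the survivors by 1 / (1 - p) during training, while passing inputs through unchanged in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))  # toy model
dropout = model[1]           # the nested Dropout submodule
x = torch.ones(1, 8)

model.train()                # propagates training=True to every submodule
print(dropout.training)      # True
out_train = dropout(x)
# Surviving elements are scaled by 1 / (1 - 0.5) = 2.0; dropped ones become 0.0.
print(set(out_train.flatten().tolist()) <= {0.0, 2.0})  # True

model.eval()                 # propagates training=False to every submodule
print(dropout.training)      # False
out_eval = dropout(x)
print(torch.equal(out_eval, x))  # True: dropout is the identity in eval mode
```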
Batch normalization must also behave differently during training and inference. At inference time, we want to avoid having the output for a specific input depend on the statistics of the other inputs we happen to present to the model in the same batch.
As mini-batches are processed during training, PyTorch normalizes each one with its own mean and variance and, in addition, updates running estimates of the mean and variance. These running estimates are exponential moving averages (controlled by the layer's momentum parameter) and serve as an approximation of the statistics of the whole dataset.
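The update can be checked numerically. The sketch below feeds one synthetic mini-batch (made-up data for illustration) through a fresh nn.BatchNorm1d in training mode. With PyTorch's default momentum of 0.1, running_mean starts at 0 and running_var at 1, and each training-mode forward pass moves them toward the current batch statistics: running = (1 - momentum) * running + momentum * batch_statistic. PyTorch normalizes the batch itself with the biased variance but stores the unbiased variance in running_var.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)  # default momentum=0.1, eps=1e-5
x = torch.randn(16, 3) * 2.0 + 5.0   # synthetic mini-batch: mean ~5, std ~2

bn.train()
y = bn(x)

# Training-mode output uses this mini-batch's statistics (biased variance).
# At initialization weight=1 and bias=0, so no affine term is needed here.
batch_mean = x.mean(dim=0)
batch_var = x.var(dim=0, unbiased=False)
expected_y = (x - batch_mean) / torch.sqrt(batch_var + bn.eps)
print(torch.allclose(y, expected_y, atol=1e-5))  # True

# Running estimates move from their initial values (0 and 1) toward the batch
# statistics by a factor of `momentum`; running_var uses the unbiased variance.
expected_mean = (1 - bn.momentum) * torch.zeros(3) + bn.momentum * batch_mean
expected_var = (1 - bn.momentum) * torch.ones(3) + bn.momentum * x.var(dim=0, unbiased=True)
print(torch.allclose(bn.running_mean, expected_mean, atol=1e-6))  # True
print(torch.allclose(bn.running_var, expected_var, atol=1e-6))    # True
```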
When we call model.eval() on a model that contains a batch normalization module, the running estimates are frozen and used for normalization. To unfreeze the running estimates and return to using the mini-batch statistics, we call model.train(), just as we did for dropout.
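The following sketch, again with synthetic data and a standalone nn.BatchNorm1d rather than a full model, accumulates running estimates in training mode, then switches to eval mode and checks two properties: the running estimates no longer change, and a sample's output is the same whether it is processed alone or alongside very different inputs.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(num_features=3)

# Accumulate running estimates from a few synthetic training-mode mini-batches.
bn.train()
for _ in range(5):
    bn(torch.randn(16, 3) * 2.0 + 5.0)

bn.eval()  # freeze the running estimates and use them for normalization
frozen_mean = bn.running_mean.clone()
frozen_var = bn.running_var.clone()

sample = torch.randn(1, 3)
alone = bn(sample)
# The same sample, batched together with inputs from a very different range.
in_batch = bn(torch.cat([sample, torch.randn(7, 3) * 100.0]))[:1]

print(torch.allclose(alone, in_batch))           # True: independent of the other inputs
print(torch.equal(bn.running_mean, frozen_mean))  # True: no updates in eval mode

# In eval mode the output is computed from the frozen running estimates.
expected = (sample - frozen_mean) / torch.sqrt(frozen_var + bn.eps)
print(torch.allclose(alone, expected, atol=1e-6))  # True
```

Calling bn.train() again would resume normalizing with mini-batch statistics and updating the running estimates on each forward pass.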