In PyTorch, model.train() and model.eval() switch a model between training mode and evaluation (test) mode, respectively.
model.train()
Sets the model to training mode, enabling the training-specific behavior of layers such as Dropout and Batch Normalization. This mode is intended for the training phase. Dropout randomly zeroes activations on each iteration, which reduces co-adaptation between neurons and improves the model's ability to generalize. Batch Normalization normalizes layer inputs using the statistics of the current mini-batch and updates its running mean and variance, which stabilizes training and speeds up convergence.
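As a minimal sketch of this behavior, in training mode Dropout draws a fresh random mask on every forward pass, so repeated passes over the same input generally differ (the layer size and probability here are arbitrary):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

drop = nn.Dropout(p=0.5)
drop.train()  # training mode: Dropout actively zeroes units

x = torch.ones(1, 100)
out1 = drop(x)
out2 = drop(x)

# A different random mask is drawn on each pass, so the two outputs
# almost surely differ; surviving units are scaled by 1/(1-p) = 2.
print(torch.equal(out1, out2))
```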
model.eval()
Sets the model to evaluation (test) mode, switching training-specific layers such as Dropout and Batch Normalization to their inference behavior: Dropout becomes an identity operation, and Batch Normalization normalizes with the running statistics accumulated during training rather than with per-batch statistics. This mode is intended for the testing phase, where we care about the model's outputs and want them to be deterministic and consistent, not about the stochastic operations used during training.
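A small sketch of this difference (layer sizes are arbitrary): in evaluation mode Dropout passes its input through unchanged, and Batch Normalization reads its stored running statistics instead of computing and updating them from the current batch:

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
bn = nn.BatchNorm1d(4)

drop.eval()
bn.eval()

x = torch.ones(2, 4)

# In eval mode Dropout is the identity.
print(torch.equal(drop(x), x))  # True

# In eval mode BatchNorm uses its running buffers (initially mean 0,
# var 1) and does not update them from the batch statistics of x.
before = bn.running_mean.clone()
_ = bn(x)
print(torch.equal(bn.running_mean, before))  # True
```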
In practical applications, we usually switch modes dynamically during training and testing: call model.train() before each training phase to enable training-specific behavior, and call model.eval() before validation or testing to disable it and obtain accurate, reproducible results.
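The typical pattern looks like the sketch below; the model, train_loader, val_loader, criterion, and optimizer names are purely illustrative assumptions, not part of any fixed API:

```python
import torch

def run_epoch(model, train_loader, val_loader, criterion, optimizer):
    # Training phase: enable Dropout/BatchNorm training behavior.
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

    # Evaluation phase: switch to inference behavior and skip autograd.
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            preds = model(inputs).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total
```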
When using model.eval(), also pay attention to the following points:
- model.eval() modifies the model's state in place and returns the model itself, so the call can be chained.
- model.eval() does not change any parameters; it only flips the module's training flag. It also does not disable gradient tracking by itself; to avoid unnecessary computation and memory consumption during testing, additionally wrap the evaluation in torch.no_grad().
- In evaluation mode, Batch Normalization layers use the fixed running mean and variance accumulated during training, so the test data is normalized with the same statistics the model learned from the training data. In practice, model.eval() is therefore combined with the
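These points can be checked directly: model.eval() returns the very module it was called on, flips only the training flag, and on its own does not stop autograd from tracking operations; that is the job of torch.no_grad():

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 1)

# eval() flips the training flag in place and returns the same module.
returned = model.eval()
print(returned is model)   # True
print(model.training)      # False

# eval() alone does NOT disable gradient tracking ...
out = model(torch.randn(1, 3))
print(out.requires_grad)   # True

# ... whereas torch.no_grad() does.
with torch.no_grad():
    out = model(torch.randn(1, 3))
print(out.requires_grad)   # False
```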
torch.no_grad()
a context manager and use itmodel.eval()
withtorch.no_grad()
For example:
model.eval()  # switch Dropout/BatchNorm to inference behavior
with torch.no_grad():  # disable gradient tracking for the whole loop
    for inputs, labels in test_loader:
        outputs = model(inputs)
        ...
Here, the torch.no_grad() context manager wraps the entire testing loop to turn off gradient computation, while model.eval() puts the model in evaluation mode. Together they ensure that no autograd state is recorded during testing and that Dropout and Batch Normalization behave deterministically, normalizing the test data with the same statistics learned during training.