What are model.train() and model.eval()?

PyTorch provides two methods to switch a model between training and evaluation (inference) modes: model.train() and model.eval().

1. model.train()

When building a neural network in PyTorch, call model.train() before the training loop so that BN (Batch Normalization) and Dropout use their training behavior.

  • If the model contains BN (Batch Normalization) or Dropout layers, call model.train() before training.
  • model.train() ensures that each BN layer normalizes with the mean and variance of the current batch. For Dropout, model.train() randomly drops a subset of network connections on each forward pass while the remaining parameters are trained and updated.
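As a minimal sketch (the layer sizes here are illustrative), calling model.train() recursively sets the training flag on every submodule, which is what switches BN and Dropout into their training behavior:

```python
import torch
import torch.nn as nn

# A small model containing both a BatchNorm layer and a Dropout layer.
model = nn.Sequential(
    nn.Linear(10, 32),
    nn.BatchNorm1d(32),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(32, 2),
)

model.train()                 # switch the whole model to training mode
print(model.training)         # True
print(model[1].training)      # True: the BN submodule's flag is set too
```

The flag propagates to every child module, so one call at the top of the training loop is enough.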

2. model.eval()

model.eval() disables the training-time behavior of Batch Normalization and Dropout.

  • If the model contains BN (Batch Normalization) or Dropout layers, call model.eval() before testing so that BN and Dropout switch to their inference behavior and no longer change.
  • model.eval() ensures that each BN layer uses the running mean and variance accumulated over all training data, i.e. the BN statistics stay fixed during testing. For Dropout, model.eval() keeps all network connections; no neurons are randomly discarded.
  • In eval mode, PyTorch automatically freezes BN and Dropout: instead of averaging over the current batch, it uses the values learned during training. Otherwise, a small test batch size can badly skew the BN statistics.
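A minimal sketch of inference with model.eval() (the model here is illustrative). Note that in eval mode even a batch of size 1 passes through BN without problems, because no batch statistics need to be computed:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 32),
    nn.BatchNorm1d(32),
    nn.Dropout(p=0.5),
    nn.Linear(32, 2),
)

model.eval()                      # BN uses running stats, Dropout is a no-op
with torch.no_grad():             # additionally skip gradient tracking for inference
    x = torch.randn(1, 10)        # a single sample is fine in eval mode
    y = model(x)

print(model.training)             # False
print(y.shape)                    # torch.Size([1, 2])
```

model.eval() and torch.no_grad() are complementary: the first changes layer behavior, the second only disables autograd bookkeeping.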

Why use model.eval() when testing?

After training on the training samples, the resulting model is used on the test samples. Before calling model(test_input), add model.eval(); otherwise a forward pass over the input data will still update the BN running statistics, even though no training step is performed. This is inherent to how the BN and Dropout layers in the model work.

In eval mode, PyTorch automatically freezes BN and Dropout: instead of using the current batch's statistics, it uses the values accumulated during training. Otherwise, if the test batch size is very small, the BN layer can badly distort the output (for a generative model, this shows up as severe color distortion in the generated images).
model.eval() must be called whenever the model is not training. Without it, the statistics of some layers keep changing, so the network produces different results on every run and the output quality varies unpredictably.

In other words, with model.eval() at test time, the network uses the batch-normalization statistics learned during training, and dropout is disabled.
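The claim above can be checked directly: a forward pass through a BN layer in train mode updates its running statistics even though no backward pass or optimizer step happens, while a forward pass in eval mode leaves them untouched. A small sketch:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) + 3.0          # data with mean far from the initial running_mean of 0

before = bn.running_mean.clone()

bn.train()
_ = bn(x)                            # forward pass alone moves running_mean/running_var
after_train = bn.running_mean.clone()

bn.eval()
_ = bn(x)                            # eval-mode forward: statistics stay frozen
after_eval = bn.running_mean.clone()

print(torch.equal(before, after_train))      # False: stats changed in train mode
print(torch.equal(after_train, after_eval))  # True: stats unchanged in eval mode
```

This is exactly why "just running the test data through the model" without model.eval() silently alters the model.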

Batch Normalization

Its role is to normalize the activations of intermediate layers, so that the feature distribution each layer learns is not destroyed as training proceeds. Training operates on mini-batches, but testing is often done on a single sample, where there is no batch to compute statistics from. Since the parameters are fixed once training is finished, the running mean and variance accumulated over all training batches are used directly at test time. Hence every batch-normalization layer behaves differently during training and during testing.
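The two behaviors can be contrasted on the same input. With the affine parameters disabled (affine=False is used here only so the output is pure normalization), train mode standardizes the batch with its own statistics, while eval mode normalizes with the running statistics, which are still close to their initial values of 0 and 1:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(2, affine=False)   # no gamma/beta, output is pure normalization
x = torch.randn(16, 2) * 2 + 5         # batch with mean ~5, std ~2

bn.train()
y_train = bn(x)
# Training mode uses the batch's own mean/var, so the output is standardized:
print(y_train.mean(dim=0))             # ≈ 0 per feature

bn.eval()
y_eval = bn(x)
# Eval mode uses the running statistics instead, so the same input
# is NOT standardized to zero mean:
print(y_eval.mean(dim=0))              # far from 0
```

The gap between y_train and y_eval shrinks as the running statistics converge toward the true data statistics over many training batches.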

Dropout

Its role is to combat overfitting. In each training batch, randomly ignoring a fraction of the feature detectors (neurons) significantly reduces overfitting.

Summary

Call model.train() before training and model.eval() before testing.

  • If the model contains BN (Batch Normalization) or Dropout layers, call model.train() during training and model.eval() during testing.
  • model.train() ensures that the BN layers use the mean and variance of each batch, while model.eval() ensures that BN uses the running mean and variance accumulated over all training data;
  • For Dropout, model.train() randomly drops a subset of network connections while training and updating parameters, whereas model.eval() keeps all network connections.
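The summary above fits into one standard training/evaluation skeleton (the model, data, and hyperparameters here are illustrative stand-ins for a real DataLoader and task):

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.BatchNorm1d(8),
    nn.Dropout(p=0.2),
    nn.Linear(8, 1),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

# --- training ---
model.train()                          # BN uses batch stats, Dropout is active
for _ in range(3):                     # stand-in for iterating over a DataLoader
    x, y = torch.randn(16, 4), torch.randn(16, 1)
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()

# --- evaluation ---
model.eval()                           # BN uses running stats, Dropout disabled
with torch.no_grad():
    pred = model(torch.randn(5, 4))
print(pred.shape)                      # torch.Size([5, 1])
```

If training resumes after evaluation, remember to call model.train() again before the next epoch.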


Origin blog.csdn.net/weixin_45277161/article/details/130861935