model.train() and model.eval() in PyTorch

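After training for an epoch, we usually evaluate the model. Before testing, call model.eval(). Otherwise, even though no training step is taken, every forward pass keeps updating the running statistics of the Batch Normalization layers, and the Dropout layers keep randomly dropping units, so the test results are wrong and not reproducible. This happens because the model contains Batch Normalization and Dropout layers, whose behavior depends on the mode.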
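
A quick way to check the paragraph above: in train mode, a plain forward pass with no backward() and no optimizer step still changes a BatchNorm layer's running statistics (buffers, not learnable weights), while in eval mode it does not. This is a minimal sketch; the toy model and its layer sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model: the sizes are arbitrary, chosen only for illustration.
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))
bn = model[1]
x = torch.randn(8, 4)

weight_before = model[0].weight.detach().clone()
running_mean_before = bn.running_mean.clone()

# Forward pass in train mode, with no backward() and no optimizer step.
model.train()
model(x)
train_changed_stats = not torch.equal(bn.running_mean, running_mean_before)

# Forward pass in eval mode: the running statistics are left untouched.
model.eval()
stats_after_train = bn.running_mean.clone()
model(x)
eval_changed_stats = not torch.equal(bn.running_mean, stats_after_train)

# The learnable weights never change without an optimizer step.
weights_changed = not torch.equal(model[0].weight.detach(), weight_before)

print(train_changed_stats, eval_changed_stats, weights_changed)  # True False False
```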

model.train() and model.eval()

In PyTorch, a model can be set to one of two modes: train mode or eval mode.
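
The current mode is stored in each module's `training` attribute, and both calls apply recursively to all submodules. A minimal sketch with an arbitrary two-layer model:

```python
import torch.nn as nn

# Hypothetical two-layer model used only to show the mode flag.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

# Modules start in train mode.
print(model.training)  # True

# eval() sets training=False on the model and every submodule, and returns the model itself.
same = model.eval()
print(same is model, [m.training for m in model.modules()])  # True [False, False, False]

# train() flips them back; model.train(False) is equivalent to model.eval().
model.train()
print(all(m.training for m in model.modules()))  # True
```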

model.train(): enables the training-time behavior of Batch Normalization and Dropout. In train mode, each Dropout layer randomly zeroes activations with probability p (for example, p=0.2 keeps about 80% of the units, which corresponds to keep_prob=0.8 in TensorFlow terminology) and scales the kept units by 1/(1-p). Each Batch Normalization layer normalizes with the mean and var of the current mini-batch and keeps updating its running estimates of them.
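
The two train-mode behaviors can be checked directly. This sketch assumes PyTorch's defaults (BatchNorm `momentum=0.1`, running mean initialized to zeros); the input sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout in train mode: each element is zeroed with probability p,
# and the surviving elements are scaled by 1 / (1 - p).
p = 0.2
drop = nn.Dropout(p=p)
drop.train()
out = drop(torch.ones(10_000))
kept_fraction = (out != 0).float().mean().item()
survivors = out[out != 0]
scale_ok = torch.allclose(survivors, torch.full_like(survivors, 1 / (1 - p)))

# BatchNorm in train mode: normalize with the current batch's statistics,
# then update the running mean: (1 - momentum) * old + momentum * batch_mean.
bn = nn.BatchNorm1d(3)
bn.train()
x = torch.randn(16, 3) * 2 + 5
y = bn(x)
expected_running_mean = (1 - bn.momentum) * torch.zeros(3) + bn.momentum * x.mean(dim=0)

print(kept_fraction)  # close to 1 - p = 0.8
print(scale_ok)  # True
print(torch.allclose(bn.running_mean, expected_running_mean))  # True
print(torch.allclose(y.mean(dim=0), torch.zeros(3), atol=1e-5))  # True
```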

model.eval(): switches Batch Normalization and Dropout to their inference behavior, i.e. it switches the model to test mode. In eval mode, the Dropout layer passes all activation units through unchanged, while the Batch Normalization layer stops calculating and updating mean and var and instead normalizes with the running mean and var learned during the training phase. Note that neither mode updates the weights by itself; only an optimizer step does. Also, model.eval() does not affect how gradients are computed: autograd still records operations and stores intermediate results for backpropagation exactly as in train mode. Evaluation code simply never calls backward().
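
A sketch of the eval-mode behavior described above: Dropout becomes the identity, BatchNorm normalizes with the stored running statistics, and autograd keeps working. The toy model and the number of warm-up batches are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(3, 3), nn.BatchNorm1d(3), nn.Dropout(p=0.5))
bn, drop = model[1], model[2]

# Run a few batches in train mode so the running statistics move away from their defaults.
model.train()
for _ in range(5):
    model(torch.randn(8, 3))

model.eval()
x = torch.randn(8, 3)

# Dropout is the identity in eval mode.
z = torch.randn(8, 3)
dropout_is_identity = torch.equal(drop(z), z)

# BatchNorm uses the stored running_mean / running_var, not this batch's own statistics.
h = model[0](x)
manual = (h - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps) * bn.weight + bn.bias
bn_uses_running_stats = torch.allclose(bn(h), manual, atol=1e-6)

# Autograd is unaffected: the graph is still recorded and backward() still works.
model(x).sum().backward()
grads_exist = model[0].weight.grad is not None

print(dropout_is_identity, bn_uses_running_stats, grads_exist)  # True True True
```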

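import torch.nn as nn
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4), nn.Dropout(p=0.5))  # hypothetical example model with BN and Dropout
# For BN layers:
model.train()  # BN uses the mean and variance of each mini-batch;
model.eval()   # BN uses the running mean and variance accumulated over the training data, so it also works on a single image;
# For Dropout layers:
model.train()  # a random subset of the network connections is trained and updated;
model.eval()   # all network connections are used;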
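
The single-image point in the comments above can be seen directly: with one sample per batch, BatchNorm1d has no meaningful batch statistics, so PyTorch raises an error in train mode, while eval mode works because it relies on the stored running statistics. A minimal sketch with an arbitrary feature size of 4:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
single_sample = torch.randn(1, 4)  # a batch containing one sample

# Train mode: batch statistics cannot be computed from one value per channel.
bn.train()
try:
    bn(single_sample)
    train_error = None
except ValueError as err:
    train_error = str(err)

# Eval mode: the stored running statistics are used, so one sample is fine.
bn.eval()
out = bn(single_sample)

print(train_error is not None, out.shape)  # True torch.Size([1, 4])
```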

model.eval() and torch.no_grad()

model.eval() is often mentioned together with torch.no_grad().

torch.no_grad(): a context manager that disables gradient tracking by autograd, which speeds up computation and saves GPU memory. It does not affect the behavior of the Dropout and Batch Normalization layers.
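
A minimal sketch of the difference: under torch.no_grad() no graph is recorded, but the model's mode is unchanged, so Dropout still drops units in train mode. The toy model is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.randn(2, 4)

# Default: autograd records the operations, so the output requires grad.
out = model(x)
print(out.requires_grad)  # True

# Inside torch.no_grad(): no graph is recorded, so there is nothing to backpropagate.
with torch.no_grad():
    out_ng = model(x)
print(out_ng.requires_grad, out_ng.grad_fn)  # False None

# no_grad() does not switch modes: the model is still in train mode, so Dropout still drops units.
with torch.no_grad():
    dropped = model[1](torch.ones(1000))
print(model.training, (dropped == 0).any().item())  # True True
```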

If you don't care about memory usage or computation time, model.eval() alone is enough to get correct validation results.

Wrapping evaluation in with torch.no_grad(): further speeds things up and saves GPU memory. Because no gradients need to be computed or stored, the model runs faster and can use larger batches.
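
A sketch of how the two are commonly combined in a validation step. The function name `evaluate`, the toy model, the MSE loss, and the list of batches standing in for a DataLoader are all assumptions for illustration.

```python
import torch
import torch.nn as nn

def evaluate(model, loader):
    """Return the mean per-element MSE over `loader` (hypothetical helper).

    `loader` is any iterable of (inputs, targets) batches; the model's
    previous mode is restored afterwards.
    """
    was_training = model.training
    model.eval()                 # Dropout off, BatchNorm uses running statistics
    total, count = 0.0, 0
    with torch.no_grad():        # no graph is built: faster and uses less memory
        for inputs, targets in loader:
            loss = nn.functional.mse_loss(model(inputs), targets, reduction="sum")
            total += loss.item()
            count += targets.numel()
    model.train(was_training)    # switch back before the next training epoch
    return total / count

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.2), nn.Linear(8, 1))
val_data = [(torch.randn(16, 4), torch.randn(16, 1)) for _ in range(3)]

val_loss = evaluate(model, val_data)
print(model.training)  # True: train mode restored
```

Because eval mode removes the randomness of Dropout and freezes the BatchNorm statistics, calling `evaluate` twice on the same data gives the same result.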

For a detailed analysis, see: [PyTorch] Get the model.train() and model.eval() modes in network training (Zhihu)

Source: blog.csdn.net/ytusdc/article/details/128523707