model.eval() and model.train()

References:

  • PyTorch study notes: nn.Dropout - random drop layer
  • What exactly do model.train() and model.eval() in pytorch do? https://www.zhihu.com/question/429337764

x.1 Two modes of the model

A model has two modes, train mode and eval mode, selected by calling model.train() and model.eval() respectively.

model.train() puts the Batch Normalization and Dropout layers into their training behavior. In train mode, a Dropout layer zeroes each activation with probability p. Note that in PyTorch p is the drop probability, so p=0.2 corresponds to keep_prob=0.8 in frameworks that specify the retention probability instead. A Batch Normalization layer normalizes with the mean and var of the current batch and keeps updating its running estimates of them.

model.eval() switches the Batch Normalization and Dropout layers to their inference behavior. In eval mode, a Dropout layer passes all activations through unchanged, while a Batch Normalization layer stops updating mean and var and directly uses the running values accumulated during the training phase.
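
The two behaviors described above can be checked directly. The following is a minimal sketch, assuming a standalone nn.BatchNorm1d and nn.Dropout layer rather than a full model; the layer sizes, the input shift of 5.0, and the variable names are illustrative choices, not taken from the original notes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

drop = nn.Dropout(p=0.2)      # p is the probability of zeroing an element
bn = nn.BatchNorm1d(3)        # running_mean starts at 0, running_var at 1
x = torch.randn(8, 3) + 5.0   # a batch whose mean is far from 0

# Train mode: BatchNorm uses batch statistics and updates its running stats.
bn.train()
drop.train()
_ = bn(x)
mean_after_train = bn.running_mean.clone()

# Eval mode: BatchNorm uses the stored running stats and does not update them.
bn.eval()
drop.eval()
_ = bn(x)
mean_after_eval = bn.running_mean.clone()

# Eval mode: Dropout is the identity function.
dropout_is_identity = torch.equal(drop(x), x)

print(mean_after_train, mean_after_eval, dropout_is_identity)
```

The running mean moves away from zero after the forward pass in train mode and stays the same after the forward pass in eval mode.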

Calling model.eval() switches the model to test mode. Note, however, that model.eval() does not affect gradient computation: autograd still records the graph and stores intermediate results exactly as in train mode, and calling backward() still produces gradients. It also does not freeze the weights by itself. The weights stay fixed during evaluation only because the evaluation loop normally runs no backward pass and no optimizer step. To actually skip gradient tracking, use torch.no_grad().
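
To illustrate that eval mode leaves autograd untouched, here is a small sketch, assuming a toy nn.Linear model and an SGD optimizer chosen only for demonstration. It shows that a backward pass and an optimizer step still work, and still change the weights, after model.eval().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model.eval()                              # only changes layer behavior flags

x = torch.randn(3, 4)
out = model(x)
tracks_grad = out.requires_grad           # the autograd graph is still built

before = model.weight.detach().clone()
out.sum().backward()                      # gradients are computed as usual
has_grad = model.weight.grad is not None
opt.step()                                # and the weights are still updated
weights_changed = not torch.equal(before, model.weight.detach())

print(tracks_grad, has_grad, weights_changed)
```

This is why evaluation code usually combines model.eval() with torch.no_grad() and simply never calls the optimizer.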

x.2 Dropout

Dropout is described in textbooks as randomly deactivating neurons. In the code implementation, each element of the layer's output is set to 0 independently with probability p, so the neuron at that position contributes nothing to the following layers for that forward pass. In PyTorch the surviving elements are scaled by 1/(1-p) during training so that the expected value of the output is unchanged, which is why no rescaling is needed in eval mode.
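
The zeroing and rescaling can be observed on a tensor of ones. This is a minimal sketch, assuming p=0.5 and an input of 1000 elements, both picked for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

p = 0.5
drop = nn.Dropout(p=p)
drop.train()                      # dropout is only active in train mode

x = torch.ones(1000)
y = drop(x)

# Roughly a fraction p of the elements is zeroed.
zero_fraction = (y == 0).float().mean().item()

# Every surviving element is scaled by 1 / (1 - p).
kept_values = y[y != 0]
expected_scale = 1.0 / (1.0 - p)

print(zero_fraction, kept_values.unique(), expected_scale)
```

The zeroed positions are resampled on every forward pass, so two calls on the same input generally drop different elements.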

x.3 Difference between torch.no_grad() and model.eval()

torch.no_grad() acts on global (thread-local) autograd state rather than on a particular model. It is usually used as a context manager in a with statement to turn off automatic differentiation for the code inside the block. The current state can be checked with torch.is_grad_enabled(); if it returns True, gradients are being tracked.
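
A short sketch of this behavior, assuming a single tensor w with requires_grad=True as a stand-in for model parameters:

```python
import torch

w = torch.ones(3, requires_grad=True)

enabled_outside = torch.is_grad_enabled()      # gradient tracking is on by default

with torch.no_grad():
    enabled_inside = torch.is_grad_enabled()   # off inside the block
    y = (w * 2).sum()                          # no graph is recorded for y

enabled_after = torch.is_grad_enabled()        # restored when the block exits
y_requires_grad = y.requires_grad

print(enabled_outside, enabled_inside, enabled_after, y_requires_grad)
```

Because y was computed without a graph, calling y.backward() here would raise an error.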

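model.eval() applies to the model: it sets the training attribute of the model and all of its submodules to False, which is equivalent to calling model.train(False). It does not change whether gradients are tracked.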
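
The propagation of the training flag can be inspected with model.modules(). The sketch below assumes a small nn.Sequential model made up for the example.

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(4, 4),
    nn.BatchNorm1d(4),
    nn.Dropout(0.5),
)

# modules() yields the container itself followed by its submodules.
flags_default = [m.training for m in model.modules()]   # new modules start in train mode

model.eval()
flags_eval = [m.training for m in model.modules()]

model.train()
flags_train = [m.training for m in model.modules()]

print(flags_default, flags_eval, flags_train)
```

Layers such as Dropout and BatchNorm read this flag in their forward pass to decide which behavior to use.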

Origin blog.csdn.net/qq_43369406/article/details/131454423