The difference between model.train() and model.eval() in PyTorch

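The difference between model.train() and model.eval() matters mainly for two kinds of layers: Batch Normalization and Dropout. In training mode, Dropout randomly zeroes elements and BatchNorm normalizes with the statistics of the current batch while updating its running mean and variance. In evaluation mode, Dropout passes its input through unchanged and BatchNorm normalizes with the stored running statistics instead.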
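To make the BatchNorm half of this concrete, here is a minimal sketch in plain Python (no PyTorch needed). MiniBatchNorm is a hypothetical, simplified class for a single scalar feature with no learnable scale or shift. It borrows the defaults documented for torch.nn.BatchNorm1d (momentum 0.1, eps 1e-5, running mean starting at 0 and running variance at 1), but it is an illustration, not the library's implementation.

```python
import math


class MiniBatchNorm:
    """Hypothetical, simplified BatchNorm for one scalar feature (no affine weight/bias)."""

    def __init__(self, momentum=0.1, eps=1e-5):
        self.training = True  # the same flag that Module.train()/eval() toggles
        self.momentum = momentum
        self.eps = eps
        self.running_mean = 0.0
        self.running_var = 1.0

    def forward(self, batch):
        if self.training:
            # Training mode: normalize with this batch's statistics (needs >= 2 values)
            n = len(batch)
            mean = sum(batch) / n
            var = sum((x - mean) ** 2 for x in batch) / n  # biased variance, used to normalize
            unbiased_var = var * n / (n - 1)               # unbiased variance, used for the running estimate
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * unbiased_var
        else:
            # Eval mode: use the stored running statistics and update nothing
            mean, var = self.running_mean, self.running_var
        return [(x - mean) / math.sqrt(var + self.eps) for x in batch]


bn = MiniBatchNorm()
batch = [1.0, 2.0, 3.0, 4.0]

train_out = bn.forward(batch)  # normalized with the batch mean 2.5
bn.training = False
eval_out = bn.forward(batch)   # normalized with running_mean, now 0.25

print(train_out)
print(eval_out)
```

The same batch gives different outputs in the two modes: in training mode the outputs have zero mean, while in eval mode they are shifted because the running mean has only moved 10% of the way toward 2.5 after one update. This is why forgetting model.eval() at test time changes a model's predictions.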

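1. The source code of model.train() and model.eval() (methods of torch.nn.Module) is shown below. Reading it is only half the story; the key point is that train(mode) sets self.training on the module and, recursively, on every submodule, while eval() is simply train(False). So at any moment, you need to know whether self.training is currently True or False.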

    def train(self, mode=True):
        r"""Sets the module in training mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self
       
    def eval(self):
        return self.train(False)
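The recursion can be checked with a stripped-down imitation. MiniModule below is a hypothetical class that keeps only the training flag, children() and the train()/eval() logic from the excerpt above; it is a sketch under that assumption, not torch.nn.Module itself.

```python
class MiniModule:
    """Hypothetical stand-in for torch.nn.Module: only the training flag and train()/eval()."""

    def __init__(self, *children):
        self.training = True  # nn.Module also starts in training mode
        self._children = list(children)

    def children(self):
        # Direct children only; grandchildren are reached through recursion in train()
        return iter(self._children)

    def train(self, mode=True):
        self.training = mode
        for module in self.children():
            module.train(mode)  # propagate the flag down the tree
        return self

    def eval(self):
        return self.train(False)


leaf = MiniModule()
block = MiniModule(leaf)
model = MiniModule(block, MiniModule())

model.eval()
print(model.training, block.training, leaf.training)  # False False False
model.train()
print(leaf.training)  # True
```

A single call on the top-level model is enough: the flag reaches the nested leaf because each module forwards the call to its own children. Returning self also allows chaining such as model.eval().forward(...) in real PyTorch code.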

2. Take Dropout as an example. In its source code below, self.training is the same attribute that step 1 sets: Dropout inherits from _DropoutNd, _DropoutNd inherits from Module, and Module owns the self.training attribute. Because forward() passes self.training to F.dropout, calling train() or eval() controls whether dropout is actually applied.

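    import torch.nn.functional as F
    from torch.nn.modules.dropout import _DropoutNd

    class Dropout(_DropoutNd):
        """(docstring omitted)"""

        # Older PyTorch releases (the version this excerpt comes from) decorated
        # forward() with @weak_script_method; current releases have no decorator here.
        def forward(self, input):
            # self.training is inherited from Module and flipped by train()/eval()
            return F.dropout(input, self.p, self.training, self.inplace)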
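The effect of self.training on Dropout can be imitated in plain Python. The dropout() helper and MiniDropout class below are hypothetical; they only mirror the documented behaviour of F.dropout: in training mode each element is zeroed with probability p and the survivors are scaled by 1/(1 - p), and in evaluation mode the input is returned unchanged.

```python
import random


def dropout(values, p=0.5, training=True, rng=random):
    """Hypothetical pure-Python imitation of F.dropout's train/eval behaviour."""
    if not training or p == 0.0:
        return list(values)           # eval mode (or p == 0): identity
    if p == 1.0:
        return [0.0 for _ in values]  # every element is dropped
    scale = 1.0 / (1.0 - p)           # "inverted dropout" keeps the expected value unchanged
    return [0.0 if rng.random() < p else v * scale for v in values]


class MiniDropout:
    """Hypothetical stand-in for nn.Dropout: forward() just passes self.training along."""

    def __init__(self, p=0.5):
        self.p = p
        self.training = True

    def forward(self, values):
        return dropout(values, self.p, self.training)


layer = MiniDropout(p=0.5)
x = [1.0] * 8

print(layer.forward(x))  # each element is randomly 0.0 or 2.0
layer.training = False   # what eval() would set
print(layer.forward(x))  # [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
```

Scaling the kept elements during training, rather than scaling down at test time, is what lets evaluation mode be a plain identity.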

Source: blog.csdn.net/tailonh/article/details/111213211