pytorch中model.train()和model.eval()的区别

model.train()和model.eval()的区别主要在于Batch Normalization和Dropout两层。

1、model.train()和model.eval()对应的源代码,如下所示,但是仅仅关注这一部分是不够的,现在需要记住当前的self.training的值是True还是False。

    def train(self, mode=True):
        r"""Sets the module in training mode.

        This has any effect only on certain modules. See documentations of
        particular modules for details of their behaviors in training/evaluation
        mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`,
        etc.

        Returns:
            Module: self
        """
        self.training = mode
        for module in self.children():
            module.train(mode)
        return self
       
	 def eval(self):
        return self.train(False)

2、下边以Dropout为例,进入其对应的源代码,下方对应的self.training就是第一步中的self.training,原因在于Dropout继承了 _DropoutNd类,而 _DropoutNd由继承了Module类,Module类中自带变量self.training,通过这种方法,来控制train/eval模型下是否进行Dropout。

class Dropout(_DropoutNd):
    '''
    	balabala
    '''

    @weak_script_method
    def forward(self, input):
        return F.dropout(input, self.p, self.training, self.inplace)

猜你喜欢

转载自blog.csdn.net/tailonh/article/details/111213211
今日推荐