model.eval()和with torch.no_grad()

model.eval()和with torch.no_grad()的区别

先给出结论: 这两个的所起的作用是不同的。

  • model.eval

    model.eval()的作用是不启用Batch Normalization 和 Dropout. 相当于告诉网络,目前在eval模式,dropout层不会起作用(而在训练阶段,dropout会随机以一定的概率丢弃掉网络中间层的神经元,而且在实际操作过程中一般不在卷积层设置该操作,而是在全连接层设置,因为全连接层的参数量一般远大于卷积层),batch normalization也有不同的作用(在训练过程中会对每一个特征维做归一化操作,对每一批量输入算出mean和std,而在eval模式下BN层将能够使用全部训练数据的均值和方差,即测试过程中不再针对测试样本计算mean和std,而是用训练好的值)。

  • with torch.no_grad()

    当我们计算梯度的时候,我们需要缓存前向传播过程中大量的中间输出,因为在反向传播pytoch自动计算梯度时需要用到这些值。而我们在测试时,我们不需要计算梯度,那么也就意味着我们不需要在forward的时候保存这些中间输出。此外,在测试阶段,我们也不需要构造计算图(这也需要一定的存储开销)。Pytorch为我们提供了一个上下文管理器,torch.no_grad,在with torch.no_grad() 管理的环境中进行计算,不会生成计算图,不会存储为计算梯度而缓存的中间值。

总结: 当网络中出现batch normalization或者dropout这样的在training,eval时表现不同的层,应当使用model.eval()。在测试时用with torch.no_grad()会节省存储空间。

另外需要注意的是,即便不使用with torch.no_grad(),在测试只要你不调用loss.backward()就不会计算梯度,with torch.no_grad()的作用只是节省存储空间。当然,在测试阶段,可以两者一起使用,效果更好。

猜你喜欢

转载自blog.csdn.net/weixin_46202290/article/details/127777510
今日推荐