pytorch的no_grad()用法

no_grad() 方法是 PyTorch 中的一个上下文管理器,在进入该上下文管理器时禁止梯度的计算,从而减少计算的时间和内存,加速模型的推理阶段和参数更新。在推理阶段,只需进行前向计算,而不需要计算和保存每个操作的梯度。在参数更新时,我们只需要调整参数,并不需要计算梯度,而在训练阶段,需要进行反向传播以获取梯度,并对其进行参数更新。

使用 no_grad() 方法可以避免由于不必要的梯度计算而导致计算图占用过多的内存,从而降低了程序的性能。例如,以下代码将比其中不包含 no_grad() 的代码运行得快:

import torch

x = torch.randn(2, 2, requires_grad=True)
y = torch.randn(2, 2, requires_grad=True)

with torch.no_grad():
    z = x + y

猜你喜欢

转载自blog.csdn.net/weixin_40895135/article/details/130029385