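A context manager that disables gradient calculation. Disabling gradient calculation is useful for inference, when we are sure we will not call Tensor.backward(). It reduces memory consumption for computations that would otherwise have requires_grad=True. In this mode, the result of every computation has requires_grad=False, even when the inputs have requires_grad=True. This context manager is thread local: it does not affect computation in other threads. The class also works as a decorator; instantiate it with parentheses, i.e. @torch.no_grad().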
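The following is a small sketch of the memory claim and the thread-local behavior. It assumes a recent PyTorch version, and `worker` and `result` are illustrative names. Under no_grad, the output has no grad_fn. That means no graph node is created that would keep values saved for a backward pass. A thread started inside the block still has grad mode enabled, because the setting is stored per thread.

```python
import threading

import torch

x = torch.tensor([1.0], requires_grad=True)

# Grad mode on (the default): the multiplication is recorded in the autograd graph.
y = x * 2
print(y.requires_grad, y.grad_fn is not None)  # True True

result = {}

def worker():
    # Runs in a separate thread started while the main thread is inside no_grad.
    result["enabled"] = torch.is_grad_enabled()
    result["requires_grad"] = (x * 2).requires_grad

with torch.no_grad():
    # No graph node is created here, so nothing is saved for a backward pass.
    y_inf = x * 2
    print(y_inf.requires_grad, y_inf.grad_fn)  # False None

    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(result)  # {'enabled': True, 'requires_grad': True}
```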
Syntax
torch.no_grad()
Example
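import torch

x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():
    y = x * 2
print(y.requires_grad)  # False

@torch.no_grad()  # note the parentheses: the class must be instantiated
def doubler(x):
    return x * 2

z = doubler(x)
print(z.requires_grad)  # False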
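The example above only inspects requires_grad. The sketch below also shows two practical consequences, assuming a recent PyTorch version. First, output produced under no_grad cannot be backpropagated through, so backward() raises a RuntimeError. Second, grad mode is disabled only while the decorated function runs and is restored afterwards. The function here also returns torch.is_grad_enabled(), purely for illustration.

```python
import torch

x = torch.tensor([1.0], requires_grad=True)

@torch.no_grad()
def doubler(t):
    # Runs with grad mode disabled for the duration of each call.
    return t * 2, torch.is_grad_enabled()

z, enabled_inside = doubler(x)
print(enabled_inside, torch.is_grad_enabled())  # False True

# z has no grad_fn, so there is nothing to backpropagate through.
try:
    z.sum().backward()
    backward_failed = False
except RuntimeError:
    backward_failed = True
print(backward_failed)  # True

# After the call, ordinary operations are tracked again.
w = x * 2
print(w.requires_grad)  # True
```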
Implementation
# Excerpt from torch/autograd/grad_mode.py (PyTorch). It relies on torch,
# typing.Any and _DecoratorContextManager from the surrounding module.
class no_grad(_DecoratorContextManager):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parenthesis.)

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::

        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        # Save the current grad mode, then disable gradient tracking.
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        # Restore the saved mode; this also runs when the with-block raises.
        torch.set_grad_enabled(self.prev)
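The key idea in __enter__/__exit__ is to save the previous grad mode instead of unconditionally re-enabling it. This keeps nested blocks and exceptions well-behaved. The sketch below mirrors that logic in a hypothetical SimpleNoGrad class. The class name and its __call__ decorator support are illustrative assumptions: the real class inherits its decorator behavior from _DecoratorContextManager, which is not reproduced here. The usage lines then check nesting, exception handling and the decorator form.

```python
import functools

import torch


class SimpleNoGrad:
    """Hypothetical, simplified stand-in for torch.no_grad (not the real class)."""

    def __init__(self) -> None:
        self.prev = False

    def __enter__(self) -> None:
        # Remember the current grad mode, then disable gradient tracking.
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Restore the saved mode; this also runs when the body raises.
        torch.set_grad_enabled(self.prev)

    def __call__(self, func):
        # Decorator form (an assumption, not _DecoratorContextManager's code):
        # each call of func runs inside a fresh SimpleNoGrad context.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with SimpleNoGrad():
                return func(*args, **kwargs)

        return wrapper


x = torch.tensor([1.0], requires_grad=True)

with SimpleNoGrad():
    outer_rg = (x * 2).requires_grad  # False
    with torch.enable_grad():
        inner_rg = (x * 2).requires_grad  # True: enable_grad overrides locally
    with SimpleNoGrad():
        pass
    # The inner SimpleNoGrad saved prev=False, so grad stays disabled here.
    still_disabled = not torch.is_grad_enabled()  # True
after_block = torch.is_grad_enabled()  # True

try:
    with SimpleNoGrad():
        raise ValueError("boom")
except ValueError:
    pass
after_error = torch.is_grad_enabled()  # True


@SimpleNoGrad()
def doubler(t):
    return t * 2


print(outer_rg, inner_rg, still_disabled, after_block, after_error)  # False True True True True
print(doubler(x).requires_grad)  # False
```

Saving prev is what lets an inner block exit without accidentally turning gradients back on for the enclosing block.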