PyTorch Functions Explained: torch.no_grad

Category: index of the PyTorch Functions Explained series


torch.no_grad is a context manager that disables gradient computation. Disabling gradient computation is useful for inference, when you are sure you will not call Tensor.backward(); it reduces memory consumption for computations that would otherwise require requires_grad=True. In this mode, the result of every computation has requires_grad=False, even when the inputs have requires_grad=True. The context manager is thread-local, so it does not affect computation running in other threads. The class can also be used as a decorator.

Syntax

torch.no_grad()

Example

import torch

x = torch.tensor([1.], requires_grad=True)

# Used as a context manager
with torch.no_grad():
    y = x * 2
y.requires_grad    # False

# Used as a decorator (instantiate with parentheses)
@torch.no_grad()
def doubler(x):
    return x * 2

z = doubler(x)
z.requires_grad    # False
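
The thread-local behavior mentioned above can be checked with torch.is_grad_enabled(). The following is a minimal sketch (the helper name worker is made up for illustration): entering no_grad in the main thread does not disable gradient tracking for work done in another thread.

import threading
import torch

def worker():
    # Runs in a separate thread: the main thread's no_grad does not
    # apply here, so gradient tracking is still enabled by default.
    a = torch.tensor([1.], requires_grad=True)
    b = a * 2
    print("worker requires_grad:", b.requires_grad)   # True

with torch.no_grad():
    print("main is_grad_enabled:", torch.is_grad_enabled())  # False
    t = threading.Thread(target=worker)
    t.start()
    t.join()
print("main is_grad_enabled:", torch.is_grad_enabled())      # True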

Implementation

class no_grad(_DecoratorContextManager):
    r"""Context-manager that disabled gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parenthesis.)

    .. note::
        No-grad is one of several mechanisms that can enable or
        disable gradients locally; see :ref:`locally-disable-grad-doc` for
        more information on how they compare.

    .. note::
        This API does not apply to :ref:`forward-mode AD <forward-mode-ad>`.
        If you want to disable forward AD for a computation, you can unpack
        your dual tensors.

    Example::
        >>> # xdoctest: +SKIP
        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self) -> None:
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)
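
Because __enter__ saves the previous grad mode and __exit__ restores it, these context managers nest correctly. The sketch below (not part of the quoted source) combines no_grad with torch.enable_grad to show the save-and-restore behavior:

import torch

x = torch.tensor([1.], requires_grad=True)
with torch.no_grad():              # grad mode: off
    with torch.enable_grad():      # grad mode: on inside the inner block
        y = x * 2
    z = x * 3                      # grad mode restored to off on exit
print(y.requires_grad)  # True
print(z.requires_grad)  # False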


Reposted from blog.csdn.net/hy592070616/article/details/132029988