torch.autograd.grad()详解


官网链接: torch.autograd.grad — PyTorch 2.0 documentation

torch.autograd.grad(outputs, inputs, grad_outputs=None, retain_graph=None, create_graph=False, only_inputs=True, allow_unused=False, is_grads_batched=False)

  • outputs:求导的因变量(需要求导的函数)。
  • inputs:求导的自变量。
  • grad_outputs:如果outputs为标量,则grad_outputs=None,也就是说,可以不用写; 如果outputs 是向量,则此参数必须写。
  • retain_graph: True 则保留计算图, False则释放计算图。
  • create_graph:若要计算高阶导数,则必须选为True。
  • allow_unused:允许输入变量不进入计算。

代码举例:

1.output为标量

output为标量时,不需要设置grad_outputs,即保持grad_outputs=None

import torch 
from torch import autograd
x = torch.rand(3, 4)
x.requires_grad_()
print(x)

# 对x中的所有元素求和,得到的结果是标量
y = torch.sum(x)  
print(y)

grads = autograd.grad(y, x)[0]  # 不加[0]原来是个元组
print(grads)

输出结果如下:

tensor([[0.2410, 0.9354, 0.4032, 0.6099],
        [0.1518, 0.7081, 0.5910, 0.8511],
        [0.1515, 0.6720, 0.4726, 0.5018]], requires_grad=True)  # x
tensor(6.2893, grad_fn=<SumBackward0>)  # 求和得到的y是个标量
tensor([[1., 1., 1., 1.],  # 得到的梯度值
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

2.output为向量

output为向量时,求解梯度时,需要将grad_outputs设置为全1的、与output形状相同的张量,原因看上一篇博客。

y = x[:,0] + x[:,1]  # 每一行的前两列进行计算
print(y)

grad = autograd.grad(y, x, grad_outputs=torch.ones_like(y))[0]
print(grad)

输出结果:

tensor([1.1764, 0.8599, 0.8235], grad_fn=<AddBackward0>)  # y为向量
tensor([[1., 1., 0., 0.],  # 前两列的梯度值为1
        [1., 1., 0., 0.],
        [1., 1., 0., 0.]])

3.求二阶导数

需要设置create_graph=True才能计算二阶导数:

y = x ** 2
grad1 = autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]
print(grad1)

grad2 = autograd.grad(grad1, x, grad_outputs=torch.ones_like(grad1))[0]
print(grad2)

输出结果:

tensor([[0.4819, 1.8708, 0.8063, 1.2198],
        [0.3035, 1.4163, 1.1819, 1.7023],
        [0.3029, 1.3440, 0.9451, 1.0037]], grad_fn=<MulBackward0>)
tensor([[2., 2., 2., 2.],
        [2., 2., 2., 2.],
        [2., 2., 2., 2.]])

可以看到 y = x 2 y=x^2 y=x2这个式子中求出的一阶导是 2 x 2x 2x,二阶导均为2。

参考博文:详解 pytorch 中的 autograd.grad() 函数_waitingwinter的博客-CSDN博客

猜你喜欢

转载自blog.csdn.net/qq_45670134/article/details/129695551