PyTorch中的自动求导系统autograd

  深度学习模型的训练,就是不断更新权值,权值的更新需要求解梯度,求解梯度十分繁琐,PyTorch提供自动求导系统,我们只要搭建好前向传播的计算图,就能获得所有张量的梯度。

torch.autograd.backward()

torch.autograd.backward(tensors, 
                        grad_tensors=None, 
                        retain_graph=None, 
                        create_graph=False)

功能: 自动求取梯度

  • tensors: 用于求导的张量,如 loss
  • retain_graph : 保存计算图,由于PyTorch采用动态图机制,在每次反向传播之后计算图都会释放掉,如果还想继续使用,就要设置此参数为True
  • create_graph : 创建导数计算图,用于高阶求导
  • grad_tensors:多梯度权重,当有多个loss需要计算梯度时,需要设置每个loss的权值
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(x, w)
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()
print(w.grad)

调试:
在 y.backward() 处设置断点
点击 step into 进入方法,可以看到方法中只有一行,说明 y.backward() 直接调用了torch.autograd.backward()

torch.autograd.backward(self, gradient, retain_graph, create_graph)

点击单步调试 step over 返回 y.backward()
停止调试
多次执行y.backward()会报错,因为计算图被释放,解决方法是第一次反向传播时 y.backward(retain_graph=True)

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(x, w)
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()  # 正确写法 y.backward(retain_graph=True)
y.backward()  # 再执行一次反向传播
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

grad_tensors参数的用法

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(x, w)
b = torch.add(w, 1)
y0 = torch.mul(a, b)
y1 = torch.add(a, b)
loss = torch.cat([y0, y1], dim=0)
grad_t = torch.tensor([1., 2.])
loss.backward(gradient=grad_t)
print(w.grad)
tensor([9.])

说明:
y 0 = ( x + w ) × ( w + 1 ) y 0 w = 5 y0 = ( x + w) \times ( w + 1 ),\frac{\partial y_0}{\partial w}=5
y 1 = ( x + w ) + ( w + 1 ) d y 0 d w = 2 y1 = ( x + w) + ( w + 1 ),\frac{dy_0}{dw}=2
w . g r a d = y 0 × 1 + y 1 2 = 5 + 2 × 2 = 9 w.grad=y0 \times 1+y1*2 =5+2\times2=9

torch.autograd.grad()

torch.autograd.grad(outputs, 
                    inputs, 
                    grad_outputs=None, 
                    retain_graph=None, 
                    create_graph=False)

功能: 求取梯度

  • outputs: 用于求导的张量,如上例中的 loss
  • inputs : 需要梯度的张量,如上例中的w
  • create_graph : 创建导数计算图,用于高阶 求导
  • retain_graph : 保存计算图
  • grad_outputs:多梯度权重

计算 y = x 2 y=x^2 的二阶导数

x = torch.tensor([3.], requires_grad=True)
y = torch.pow(x, 2)
grad1 = torch.autograd.grad(y, x, create_graph=True)  # create_graph=True 创建导数的计算图,实现高阶求导
print(grad1)
grad2 = torch.autograd.grad(grad1[0], x)
print(grad2)
(tensor([6.], grad_fn=<MulBackward0>),)
(tensor([2.]),)

小贴士:

  1. 梯度不自动清零,在每次反向传播中会叠加
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
for i in range(3):
    a = torch.add(x, w)
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    y.backward()
    print(w.grad)
tensor([5.])
tensor([10.])
tensor([15.])

这导致我们得不到正确的结果,所以需要手动清零

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
for i in range(3):
    a = torch.add(x, w)
    b = torch.add(w, 1)
    y = torch.mul(a, b)
    y.backward()
    print(w.grad)
    w.grad.zero_()  # 梯度清零
tensor([5.])
tensor([5.])
tensor([5.])
  1. 依赖于叶子结点的结点,requires_grad默认为True
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(x, w)
b = torch.add(w, 1)
y = torch.mul(a, b)
print(a.requires_grad, b.requires_grad, y.requires_grad)
True True True
  1. 叶子结点不可执行in-place,因为前向传播记录了叶子节点的地址,反向传播需要用到叶子节点的数据时,要根据地址寻找数据,执行in-place操作改变了地址中的数据,梯度求解也会发生错误。
w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(x, w)
b = torch.add(w, 1)
y = torch.mul(a, b)
w.add_(1)
RuntimeError: a leaf Variable that requires grad has been used in an in-place operation.

in-place操作,即原位操作,在原始内存中改变这个数据,方法后接_代表in-place操作

a = torch.tensor([1])
print(id(a), a)
a = a + torch.tensor([1])  # 开辟了新的内存地址
print(id(a), a)
a += torch.tensor([1])  # in-place操作,地址不变
print(id(a), a)
3008015174696 tensor([1])
3008046791240 tensor([2])
3008046791240 tensor([3])
发布了9 篇原创文章 · 获赞 0 · 访问量 294

猜你喜欢

转载自blog.csdn.net/SakuraHimi/article/details/104598194