反向传播:计算图的追踪与停止

原文链接:计算图的追踪与停止

反向传播梯度计算是模型收敛的必须手段,今天我们就看看PyTorch中反向传播是如何实现的。

1 反向传播的基本过程

x = torch.tensor(1., requires_grad = True)
y = x ** 2
z = y + 1

与上节一样,我们构建xyz三者之间的函数关系。

所谓反向传播,是在此前计算图中记录的函数关系中,反向传播函数关系,进而求得叶节点x的导数值。

z
# tensor(2., grad_fn=<AddBackward0>)

z.grad_fn
# <AddBackward0 at 0x7fe81279b048>
# add为y与z的对应关系

使用backward方法执行反向传播:

z.backward()

反向传播结束后,即可查看叶节点的导数值。张量的grad属性存储了该点处的导数值。

# x的值与可导属性
x
# tensor(1., requires_grad=True)

# x的导数值
x.grad
# tensor(2.)

因 为 z = y + 1 = x 2 + 1 因为z=y+1=x^2+1 z=y+1=x2+1

所 以 z ′ = 2 x 所以z' = 2x z=2x

又因为 x x x的取值为1,所以 x x x的导数值就是2。这就是计算图的反向传播。

注意

在默认情况下,在一张计算图上执行反向传播只能计算一次,再次调用backward方法将报错。

z.backward()
# RuntimeError:Traceback (most recent call last)

总结

反向传播的基本概念和使用方法:

  • 反向传播的本质:函数关系的反向传播
  • 反向传播的执行条件:拥有函数关系的可微分张量
  • 反向传播的函数作用:计算叶节点的导数/微分/梯度运算结果

2 中间节点的梯度保存

在默认情况下,我们只能计算叶节点的导数值。

x = torch.tensor(1.,requires_grad = True)
y = x ** 2
z = y ** 2
z.backward()

# 中间节点的梯度值不会被保存
y.grad
# 会报错

# 可求叶子节点的梯度值
x.grad
tensor(4.)

若想保存中间节点的梯度,我们可以使用retain_grad方法:

x = torch.tensor(1.,requires_grad = True)
y = x ** 2
y.retain_grad()
z = y ** 2
z.backward()

# 会记录x的值与相对应的函数关系
y
# tensor(1., grad_fn=<PowBackward0>)

y.grad
# tensor(2.)

3 阻止计算图追踪

在默认情况下,只要初始张量是可微分张量,系统就会自动追踪其相关运算,并保存在动态图计算图关系中,我们也可通过grad_fn来查看记录的函数关系。

x = torch.tensor(1.,requires_grad = True)
y = x ** 2
y.grad_fn
# <PowBackward0 at 0x7fe8103ba160>
# pow为x与y的对应关系

但在特殊的情况下,我们并不希望可微张量从创建到运算结果输出都被记录,此时就可以使用一些方法来阻止部分运算被记录。

  • with torch.no_grad()
x = torch.tensor(1.,requires_grad = True)
y = x ** 2

with torch.no_grad():
    z = y ** 2

with相当于一个上下文管理器,with torch.no_grad()内的代码都“屏蔽”了计算图的追踪记录:

z
# tensor(1.)

# 查看z的可导性
z.requires_grad
# False

# 而y不在其中,具备可导性
y
# tensor(1., grad_fn=<PowBackward0>)
  • detach()
    此方法的的原理是创建一个不可导的相同张量来阻止计算图的追踪。
x = torch.tensor(1.,requires_grad = True)
y = x ** 2
y1 = y.detach()
z = y1 ** 2

y
tensor(1., grad_fn=<PowBackward0>)

# 不具备可导的函数关系
y1
# tensor(1.)

# 不具备可导的函数关系
z
# tensor(1.)

4 识别叶节点

由于叶节点较为特殊,在默认情况下只能计算叶节点的导数值,我们可以使用is_leaf属性查看张量是否是叶节点。

x = torch.tensor(1.,requires_grad = True)
y = x ** 2
y1 = y.detach()

x.is_leaf
# True

y.is_leaf
# False

注意

任何一个新创建的张量,无论是否可导、是否加入计算图,都是可以是叶节点。

# 经过detach的新张量
y1.is_leaf
# True

# 新创建的张量
torch.tensor([1]).is_leaf
# True

这些叶节点只是不具备requires_grad属性。也就是说叶节点一定是新的张量,但是不一定可导,同学们注意不要被混淆。

原文链接:计算图的追踪与停止

猜你喜欢

转载自blog.csdn.net/Antai_ZHU/article/details/121904517