详解Pytorch动态图的回溯机制

原文链接:详解Pytorch动态图的回溯机制

大家好,我是泰哥。《5分钟精通PyTorch》经过1个月的连载,已经介绍了张量的常规操作以及运算技巧。之后的章节就进入到深度学习部分,会以理论与代码结合的方式为大家呈现,帮助大家理解其中细节。

对于动态图回溯机制的学习与理解,首先从张量的微分计算开始入手。

注意

本节我们暂时不区分微分值、导数值、梯度值的区别,后续讲解梯度下降时再进行区分。不理解的同学统一理解为导数即可。

1 Variablerequires_grad

有同学会说在进行微分运算时需提前将Tensor类转化为Variable类,但其实在PyTorch 0.4版本后,Variable的概念被逐渐弱化,Tensor就不再是一个纯计算载体,可微分性也变成了Tensor的一个基本属性,我们只需要在创建Tensor时,通过设置requires_grad属性为True,就可规定张量可微分计算。

x = torch.tensor(1.,requires_grad = True)
x
# tensor(1., requires_grad=True)

此时张量a就是一个可微分的张量,requires_grad是它的一个属性,可以查看并修改。

# 查看可微分性
x.requires_grad
# True

# 修改可微分性
x.requires_grad = False

# 再次查看可微分性
x.requires_grad
# False

2 可微分性的属性

可微分性会体现在可微分张量参与的所有运算中。

  • requires_grad属性:可微分性
# 构建可微分张量
x = torch.tensor(1., requires_grad = True)
x
# tensor(1., requires_grad=True)

# 构建函数关系
y = x ** 2
y
# tensor(1., grad_fn=<PowBackward0>)

我们发现,此时张量y具有了一个grad_fn属性,并且取值为<PowBackward0>,我们可以查看该属性:

y.grad_fn
# <PowBackward0 at 0x200a2047208>

grad_fn存储了Tensor的微分函数,也可以说它存储了可微分张量在进行计算过程中的函数关系,此处xy进行了幂运算,就是pow方法,与上述返回属性相对应。

# 但x作为初始张量,并没有grad_fn属性
x.grad_fn

值得注意的是,y不仅和x存在幂运算关系 y = x 2 y = x^2 y=x2,更重要的是,y本身是一个由x张量计算得出的一个张量:

# 打印y
y
# tensor(1., grad_fn=<PowBackward0>)

对于一个可微分张量(x)生成的张量(y),也是可微分的:

y.requires_grad
# True

相比于xy不仅同样拥有张量的取值,也同样可微,还额外存储了x到y的函数计算信息。

我们再尝试围绕y创建新的函数关系:z = y + 1:

z = y + 1
z
# tensor(2., grad_fn=<AddBackward0>)

# z同样可微
z.requires_grad
# True

# z保存了y到z的函数函数add
z.grad_fn
# <AddBackward0 at 0x200a2037648>

可以发现z也存储了数值,并是可微的,还存储了和y的计算关系(add)。

3 回溯机制

PyTorch的张量计算过程中,如果我们设置初始张量是可微的,则在计算过程中,每一个由原张量计算得出的新张量都是可微的,并且还会保存此前一步的函数关系,这也就是所谓的回溯机制。

根据这个回溯机制,我们就能清楚地掌握张量每一步的计算过程,并据此绘制张量计算图。

4 张量计算图

借助回溯机制,我们就能将张量的复杂计算过程抽象为一张图Graph,比如我们之前定义的xyz三个张量,三者的计算关系就可以由下图进行表示。

计算图定义

计算图模型由节点nodes和边edges构成,节点表示操作符,也就是张量,节点之间的边表示张量之间的函数关系,方向则表示实际运算方向。

节点类型

在张量计算图中,虽然每个节点都表示可微分张量,但节点和节点之间却略有不同,比如在前例中:

  • yz保存了函数计算关系,但x没有
  • 可以发现z是所有计算的终点

因此可以将节点分为三类,分别是:

  1. 叶节点:初始输入的可微分张量,前例中的x
  2. 输出节点:最后计算得出的张量,前例中的z
  3. 中间节点:在一张计算图中,除了叶节点和输出节点,其他都是中间节点,前例中的y

在一张计算图中,可以有多个叶节点和中间节点,但大多数情况下,只有一个输出节点,若存在多个输出结果,我们也往往会将其保存在一个向量中。

5 计算图的动态性

PyTorch的计算图是动态计算图,会根据可微分张量的计算过程自动生成,并且伴随着新张量或运算的加入不断更新,这使得PyTorch的计算图更加灵活高效且容易构建。

而静态图(TF1)需要先先构建计算流程图,然后进行会话传入数据后才能执行。具体代码比较与分析可看TF与PyTorch该如何选择

原文链接:详解Pytorch动态图的回溯机制

猜你喜欢

转载自blog.csdn.net/Antai_ZHU/article/details/121904172