手搓GPT系列之 - 后向传播,计算图,目标函数

本问将介绍神经网络中后向传播的机制和基本原理。详细解析在后向传播过程中,计算图的生成,以及如何在计算图中应用链式规则实现自动求导的机制,并介绍了价值函数(又称损失函数)在后向传播过程中的作用。适合初步了解神经网络基本概念的同学进一步理解神经网络参数优化的过程。

当我们在写一些分类模型的训练代码的时候,总是要写一句:

loss.backward()

这句简单的代码启动了一个精妙的过程,误差信号以网络的出口为起点,逆着计算图的边传遍网络中的每一个节点,网络中的任何一个几点可以根据误差信号来调整自己的参数从而在下次预测中取得更好的效果。我们在上一篇文章中介绍了神经网络的一些基本概念,就已经介绍过计算图。我们把那个例子再拿过来。

1. 计算图

计算图的内容可以参考前文。我们这里再把前面的例子拿出来。假设我们有一个单层的神经网络模型,我们用 x x x表示输入特征值向量, W W W表示输出层权重矩阵, b b b表示输出层偏移量向量, σ \sigma σ表示激活函数, a a a表示输出结果向量。该模型的输出可以用公式表示为:
a = σ ( z ) z = y + b y = W T x a=\sigma(z)\\ z=y + b\\ y= W^Tx a=σ(z)z=y+by=WTx
上边这个公式可以表示为如下计算图:
在这里插入图片描述

2 利用计算图和链式规则自动计算梯度

所谓后向传播,其目的就是算出新一轮迭代中模型各个参数的梯度值。其中最关键的步骤就是求模型关于各个参数的偏导。有些小伙伴内心开始要犯嘀咕了,OMG,高数的恶梦又要来了。其实只要你不被自己吓倒,这玩意真没你想的那么难。准备好了我们就开车吧!

2.1 链式规则

先介绍一下链式规则。如果我们对函数 f ( g ( x ) ) f(g(x)) f(g(x))求关于 x x x的导数的时候,我们可以应用链式规则:
∂ f ( g ( x ) ) ∂ x = ∂ f ( g ( x ) ) ∂ g ( x ) ⋅ ∂ g ( x ) ∂ x \frac{\partial f(g(x))}{\partial x} = \frac{\partial f(g(x))}{\partial g(x)} \cdot \frac{\partial g(x)}{\partial x} xf(g(x))=g(x)f(g(x))xg(x)

2.2 将链式规则应用到计算图中

我们以上边这个模型为例子来介绍如何通过链式规则求偏导。这个模型一共有两个参数 W W W b b b,我们先来看看关于 W W W的偏导:
∂ σ ( z ) ∂ W = ∂ σ ( z ) ∂ z ⋅ ∂ z ∂ y ⋅ ∂ y ∂ W \frac{\partial \sigma(z)}{\partial W} = \frac{\partial \sigma(z)}{\partial z} \cdot \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial W} Wσ(z)=zσ(z)yzWy
我们再来看关于参数 b b b的偏导:
∂ σ ( z ) ∂ b = ∂ σ ( z ) ∂ z ⋅ ∂ z ∂ b \frac{\partial \sigma(z)}{\partial b} = \frac{\partial \sigma(z)}{\partial z} \cdot \frac{\partial z}{\partial b} bσ(z)=zσ(z)bz

我们把上边式子右侧的因子分别标记到计算图中,则有:

在这里插入图片描述
我们可以看到,上图中,除了 a a a节点意外,每个数据节点(圆形节点)都分到了一个偏导函数,这个偏导函数就是用于计算该节点的本地梯度。而途中任意一个数据节点的下游梯度等于上游节点的下游梯度与该节点的本地梯度之积。 a a a点作为梯度计算的起点,我们计算其下游梯度值时直接给1。
我们发现,上边两个式子有一个共同的因子 ∂ σ ( z ) ∂ z \frac{\partial \sigma(z)}{\partial z} zσ(z),从上图可以看出,从 a a a点后向传播到 W W W b b b,都需要经过 z z z点。因此 z z z点的梯度值计算一次之后,就可以保存起来分别给 W W W b b b的梯度计算使用,减少重复计算的成本。

2.3 几种计算节点的类型及其梯度的计算

下边介绍几种计算节点的类型。

2.3.1 单输入单输出

单输入和单输出的梯度计算是最简单的。这种节点可以表示为 y = f ( x ) y=f(x) y=f(x),我们用 δ \delta δ来表示上游的梯度,如下图所示:

扫描二维码关注公众号,回复: 15050933 查看本文章

在这里插入图片描述

此时 x x x的下游梯度为 δ ⋅ ∂ f ( x ) ∂ x \delta \cdot \frac{\partial f(x)}{\partial x} δxf(x)

2.3.2 多输入单输出

我们举一个简单的两输入一输出的例子:
在这里插入图片描述

在多输入的节点中,每个输入其实就是一个变量。传播到该节点的梯度值就是关于该变量的梯度。因此上图中关于 x x x的下游梯度为 δ ⋅ ∂ f ( x , y ) ∂ x \delta \cdot \frac{\partial f(x,y)}{\partial x} δxf(x,y)。关于 y y y的下游梯度为 δ ⋅ ∂ f ( x , y ) ∂ y \delta \cdot \frac{\partial f(x,y)}{\partial y} δyf(x,y)

2.3.2 单输入多输出

对于一个数据节点作为多个计算节点的输入,会获得多个上游梯度。如下图所示:
在这里插入图片描述

此时关于 x x x的下游梯度值为 δ 1 + δ 2 \delta_1 + \delta_2 δ1+δ2

3 目标函数

上边这个计算图看起来已经相当完美了,还差什么呢?慢着,如果这就完成了,那么神经网络模型中的训练数据的标签值应该放在哪里?是不是有一种把自行车拆开了又装回去,结果发现地上还差一把螺丝没拧完的感觉啊?的确,还有最后一块拼图:目标函数。关于目标函数,笔者有一篇更详细的解读,各位读者可以移步这里:深入理解Linear Regression,Softmax模型的损失函数。目标函数就是一个评判模型预测结果与测试数据集标签之间差异大小的函数。差异越大,目标函数值越大,反之目标函数值越小。因此目标函数也经常被称为损失函数。在后向传播中,我们要把损失函数挂在计算图的最末尾:

在这里插入图片描述
现在可算是终于把剩下的螺丝拧上了。

4. 小结

链式规则和计算图的结合使得自动求偏导能够在一个任意搭建的网络中实现,任何一个节点只要知道其上游的梯度,就能与自己的本地梯度一起计算出自己的下游梯度值,并传递给下游。我们目前主要的深度计算框架包括pytorch,tensorflow等,其核心都是实现了这个计算图和后向传播过程,然后每个节点的前向和后向传播逻辑,则交给生态来完成。现在的各种计算节点的前向和后向方法,已经由一些组织和个人实现了,我们只需要直接使用就可以,非常方便。

猜你喜欢

转载自blog.csdn.net/marlinlm/article/details/130096898