RNN反向求导详解

(这里是本章会用到的 GitHub 地址)

(感谢评论区

指出本文的诸多错误!!真的非常感谢!!【拜】)

RNN 的“前向传导算法”

在说明如何进行训练之前,我们先来回顾一下 RNN 的“前向传导算法。在上一章中曾经给过一个没有激活函数和变换函数的公式:

\begin{align} o_{1} &= Vs_{1} = V\left( Ux_{1} \right) = x_{1} \\ o_{2} &= Vs_{2} = V\left( Ws_{1} + Ux_{2} \right) = 2s_{1} + x_{2} \\ \ldots \\ o_{t} &= Vs_{t} = V\left( Ws_{t - 1} + Ux_{t} \right) = 2s_{t - 1} + x_{t} \end{align}

在实现层面来说,这就是一个循环的事儿,所以代码写起来会比较简单:

import numpy as np

class RNN1:
    def __init__(self, u, v, w):
        self._u, self._v, self._w = np.asarray(u), np.asarray(v), np.asarray(w)
        self._states = None

    # 激活函数
    def activate(self, x):
        return x

    # 变换函数
    def transform(self, x):
        return x

    def run(self, x):
        output = []
        x = np.atleast_2d(x)
        # 初始化 States 矩阵为零矩阵
        # 之所以把所有 States 记下来、是因为训练时(BPTT 算法)要用到
        self._states = np.zeros([len(x)+1, self._u.shape[0]])
        for t, xt in enumerate(x):
            # 对着公式敲代码即可 ( σ'ω')σ
            self._states[t] = self.activate(
                self._u.dot(xt) + self._w.dot(self._states[t-1])
            )
            output.append(self.transform(
                self._v.dot(self._states[t]))
            )
        return np.array(output)

可以用上一章说过的那个小栗子来测试一下:

  • 假设现在U,V是单位阵,W是单位阵的两倍
  • 假设输入序列为:\left( 1,0,0,\ldots,0 \right)^{T} \rightarrow \left( 0,1,0,\ldots,0 \right)^{T} \rightarrow \left( 0,0,1,\ldots,0 \right)^{T} \rightarrow \ldots \rightarrow \left( 0,0,0,\ldots,1 \right)^{T}

对应的测试代码如下:

n_sample = 5
rnn = RNN1(np.eye(n_sample), np.eye(n_sample), np.eye(n_sample) * 2)
print(rnn.run(np.eye(n_sample)))

程序输出为:


这和我们上一章推出的理论值\left( 1,0,0,\ldots,0 \right)^{T} \rightarrow \left( 2,1,0,\ldots,0 \right)^{T} \rightarrow \left( 4,2,1,\ldots,0 \right)^{T} \rightarrow \ldots \rightarrow \left( 2^{n - 1},2^{n - 2},2^{n - 3},\ldots,1 \right)^{T}是一致的(n=5

RNN 的“反向传播算法”

简洁起见,我们采用上一章第一张图所示的那个朴素网络结构:

然后做出如下符号约定:

  • \phi作为隐藏层的激活函数
  • \varphi作为输出层的变换函数
  • L_{t} = L_{t}\left( o_{t},y_{t} \right)作为模型的损失函数,其中标签y_{t}是一个 one-hot 向量;由于 RNN 处理的通常是序列数据、所以在接受完序列中所有样本后再统一计算损失是合理的,此时模型的总损失可以表示为(假设输入序列长度为n):L = \sum_{t = 1}^{n}L_{t}

为了更清晰地表明各个配置,我们可以整理出如下图所示的结构:


易知o_{t} = \varphi\left( \text{Vs}_{t} \right) = \varphi\left( \text{Vϕ}\left( Ux_{t} + Ws_{t - 1} \right) \right),其中s_{0} = \mathbf{0 =}\left( 0,0,\ldots,0 \right)^{T}。令:

o_{t}^{*} = \text{Vs}_{t},\ \ s_{t}^{*} = Ux_{t} + Ws_{t - 1}

则有:

o_{t} = \varphi\left( o_{t}^{*} \right),\ \ s_{t} = \phi(s_{t}^{*})

从而(注:统一使用“*”表示 element wise 乘法,使用“\times”表示矩阵乘法):

\frac{\partial L_{t}}{\partial o_{t}^{*}} = \frac{\partial L_{t}}{\partial o_{t}}*\frac{\partial o_{t}}{\partial o_{t}^{*}} = \frac{\partial L_{t}}{\partial o_{t}}*\varphi^{'}\left( o_{t}^{*} \right)

\frac{\partial L_{t}}{\partial V} = \frac{\partial L_{t}}{\partial Vs_{t}} \times \frac{\partial Vs_{t}}{\partial V} = \left( \frac{\partial L_{t}}{\partial o_{t}}*\varphi^{'}\left( o_{t}^{*} \right) \right) \times s_{t}^{T}

可见对矩阵V的分析过程即为普通的反向传播算法,相对而言比较平凡。由L = \sum_{t = 1}^{n}L_{t}可知,它的总梯度可以表示为:

\frac{\partial L}{\partial V} = \sum_{t = 1}^{n}{\left( \frac{\partial L_{t}}{\partial o_{t}}*\varphi^{'}\left( o_{t}^{*} \right) \right) \times s_{t}^{T}}

而事实上,RNN 的 BP 算法的主要难点在于它 State 之间的通信,亦即梯度除了按照空间结构传播(o_{t} \rightarrow s_{t} \rightarrow x_{t})以外,还得沿着时间通道传播(s_{t} \rightarrow s_{t - 1} \rightarrow \ldots \rightarrow s_{1}),这导致我们比较难将相应 RNN 的 BP 算法写成一个统一的形式(回想之前的“前向传导算法”)。为此,我们可以采用“循环”的方法来计算各个梯度

由于是反向传播算法,所以t应从n开始降序循环至 1,在此期间(若需要初始化、则初始化为 0 向量或 0 矩阵):

  • 计算时间通道上的“局部梯度” :
    \begin{align} \frac{\partial L_{t}}{\partial s_{t}^{*}} &= \frac{\partial s_{t}}{\partial s_{t}^{*}}*\left( \frac{\partial s_{t}^{T}V^{T}}{\partial s_{t}} \times \frac{\partial L_{t}}{\partial Vs_{t}} \right) = \phi'(s_t^*)*\left[V^{T} \times \left( \frac{\partial L_{t}}{\partial o_{t}}*\varphi^{'}\left( o_{t}^{*} \right) \right)\right] \\ \frac{\partial L_{t}}{\partial s_{k - 1}^{*}} &= \frac{\partial s_{k}^{*}}{\partial s_{k - 1}^{*}} \times \frac{\partial L_{t}}{\partial s_{k}^{*}} = \phi^{'}\left( s_{k - 1}^{*} \right) * \left( W^{T} \times \frac{\partial L_{t}}{\partial s_{k}^{*}} \right),\ \ (k = 1,\ldots,t) \end{align}
  • 利用时间通道上的“局部梯度”计算UW的梯度:
    \begin{align} \frac{\partial L_{t}}{\partial U} &= \sum_{k = 1}^{t}{\frac{\partial L_{t}}{\partial s_{k}^{*}} \times \frac{\partial s_{k}^{*}}{\partial U}} = \sum_{k = 1}^{t}{\frac{\partial L_{t}}{\partial s_{k}^{*}} \times x_{k}^{T}} \\ \frac{\partial L_{t}}{\partial W} &= \sum_{k = 1}^{t}{\frac{\partial L_{t}}{\partial s_{k}^{*}} \times \frac{\partial s_{k}^{*}}{\partial W}} = \sum_{k = 1}^{t}{\frac{\partial L_{t}}{\partial s_{k}^{*}} \times s_{k - 1}^{T}} \end{align}

以上即为 RNN 反向传播算法的所有推导,它比 NN 的 BP 算法要繁复不少。事实上,像这种需要把梯度沿时间通道传播的 BP 算法是有一个专门的名词来描述的——Back Propagation Through Time(常简称为 BPTT,可译为“时序反向传播算法”)

不妨举一个具体的栗子来加深理解,假设:

  • 激活函数\phi为 Sigmoid 函数
  • 变换函数\varphi为 Softmax 函数
  • 损失函数L_{t}为 Cross Entropy(感谢评论区 指出这里的错误):L_{t}\left( o_{t},y_{t} \right) = -\left[y_{t}\log o_{t}+(1-y_t)\log(1-o_t)\right]

由 NN 处的讨论可知这是一个非常经典、有效的配置,其中:

\frac{\partial L_{t}}{\partial o_{t}}*\varphi^{'}\left( o_{t}^{*} \right) = o_{t} - y_{t}

\phi^{'}\left( s_{t}^{*} \right) = \phi\left( s_{t}^{*} \right)*\left( 1 - \phi\left( s_{t}^{*} \right) \right) = s_{t}*(1 - s_{t})

从而

\frac{\partial L}{\partial V} = \sum_{t = 1}^{n}{\left( o_{t} - y_{t} \right) \times s_{t}^{T}}

tn开始降序循环至 1 的期间中,各个“局部梯度”为:

\begin{align} \frac{\partial L_{t}}{\partial s_{t}^{*}} &= V^{T} \times \left( \frac{\partial L_{t}}{\partial o_{t}}*\varphi^{'}\left( o_{t}^{*} \right) \right) = \left[s_t*(1-s_t)\right]*\left[ V^{T} \times (o_{t} - y_{t})\right] \\ \frac{\partial L_{t}}{\partial s_{k - 1}^{*}} &= W^{T} \times \left( \frac{\partial L_{t}}{\partial s_{k}^{*}}*\phi^{'}\left( s_{k - 1}^{*} \right) \right) = [s_{k - 1}*\left( 1 - s_{k - 1} \right)] * \left(W^{T} \times \frac{\partial L_{t}}{\partial s_{k}^{*}} \right),\ \ (k = 1,\ldots,t) \end{align}

由此可算出如下相应梯度:

\begin{align} \frac{\partial L_{t}}{\partial U} &= \sum_{k = 1}^{t}{\frac{\partial L_{t}}{\partial s_{k}^{*}} \times x_{k}^{T}} \\ \frac{\partial L_{t}}{\partial W} &= \sum_{k = 1}^{t}{\frac{\partial L_{t}}{\partial s_{k}^{*}} \times s_{k - 1}^{T}} \end{align}

可以看到形式相当简洁,所以我们完全可以比较轻易地写出相应实现:

class RNN2(RNN1):
    # 定义 Sigmoid 激活函数
    def activate(self, x):
        return 1 / (1 + np.exp(-x))

    # 定义 Softmax 变换函数
    def transform(self, x):
        safe_exp = np.exp(x - np.max(x))
        return safe_exp / np.sum(safe_exp)

    def bptt(self, x, y):
        x, y, n = np.asarray(x), np.asarray(y), len(y)
        # 获得各个输出,同时计算好各个 State
        o = self.run(x)
        # 照着公式敲即可 ( σ'ω')σ
        dis = o - y
        dv = dis.T.dot(self._states[:-1])
        du = np.zeros_like(self._u)
        dw = np.zeros_like(self._w)
        for t in range(n-1, -1, -1):
            st = self._states[t]
            ds = self._v.T.dot(dis[t]) * st * (1 - st)
            # 这里额外设定了最多往回看 10 步
            for bptt_step in range(t, max(-1, t-10), -1):
                du += np.outer(ds, x[bptt_step])
                dw += np.outer(ds, self._states[bptt_step-1])
                st = self._states[bptt_step-1]
                ds = self._w.T.dot(ds) * st * (1 - st)
        return du, dv, dw

    def loss(self, x, y):
        o = self.run(x)
        return np.sum(
            -y * np.log(np.maximum(o, 1e-12)) -
            (1 - y) * np.log(np.maximum(1 - o, 1e-12))
        )

注意我们设定了在每次沿时间通道反向传播时、最多往回看 10 步,这是因为我们实现的这种朴素 RNN 的梯度存在着一些不良性质,我们在下一节中马上就会进行相关的说明

指数级梯度所带来的问题

注意到由于 RNN 需要沿时间通道进行反向传播,其相应的“局部梯度”为:

\frac{\partial L_{t}}{\partial s_{k - 1}^{*}} = [s_{k - 1}*\left( 1 - s_{k - 1} \right)] * \left(W^{T} \times \frac{\partial L_{t}}{\partial s_{k}^{*}} \right)

注意到式中的每个局部梯度\frac{\partial L_{t}}{\partial s_{k}^{*}}都会“携带”一个W矩阵和一个s_{k}的 Sigmoid 系激活函数所对应的梯度s_{k}*\left( 1 - s_{k} \right),这意味着局部梯度受W和各个激活函数的梯度的影响是指数级的。姑且不考虑W而单看激活函数的梯度,回忆我们之前在 NN 处讲过的梯度问题,这里的这种指数级梯度的表现和彼时深层网络梯度的表现是几乎同理的(事实上 RNN 的时间通道长得确实很像一个深层网络)——当输入趋近于两端时,激活函数的梯度会随着传播而迅速弥散,这就是 RNN 中所谓的“梯度消失(The Vanishing Gradient)”问题。是故我们在上一小节实现 RNN 时规定在沿时间通道反向传播时最多只往回看 10 步,这是因为再往下看也没有太大意义了(可以大概地类比于多于 10 层的、以 Sigmoid 系函数作为激活函数的神经网络)(以下纯属开脑洞:这么说的话是不是能在时间通道里面传递残差然后弄一个 Residual RNN 呢……)

这当然是非常令人沮丧的结果,要知道 RNN 的一大好处就在于它能利用上历史的信息,然而梯度消失却告诉我们 RNN 能够利用的历史信息十分有限。所以针对该问题作出优化是非常有必要的,解决方案大体上分两种:

  • 选用更好的激活函数
  • 改进 State 的传递方式

第二点是 LSTMs 等特殊 RNN 的做法,这里就主要说说第一点——如何选用更好的激活函数。由 NN、CNN 处的讨论不难想到,用 ReLU 作为激活函数很有可能是个不错的选择;不过由于梯度是指数级的这一点不会改变,此时我们可能就会面临另一个问题:“梯度爆炸(The Exploding Gradient)”(注:不是说 Sigmoid 系函数就不会引发梯度爆炸、因为当矩阵W的元素很大时同样会爆炸,只是相对而言更容易引发梯度消失而已)。不过相比起梯度消失问题来讲,梯度爆炸相对而言要显得更“友好”一些,这是因为:

  • 梯度爆炸一旦发生,是会迅速反映到结果上来的(比如一堆数变成了 NaN)
  • 梯度爆炸可以通过简单的设定阈值来得到改善

而梯度消失相比之下,既难以直接从结果看出、又没有特别平凡的解决方案。现有的较常用的方法为调整参数的初值、进行适当的正则化、使用 ReLU(需要小心梯度爆炸)等等

关于为何 LSTMs 能够解决梯度消失,直观上来说就是上方时间通道是简单的线性组合、从而使得梯度不再是指数级的。详细的推导可以参见各种论文(比如说这篇),我就不在这里献丑了 ( σ'ω')σ

以上就大致地说了说 RNN 的 BPTT 算法,主要要注意的其实就是时间通道上的 BP 算法。如果把时间通道看成一个神经网络的话,运用局部梯度来反向传播其实相当自然

猜你喜欢

转载自blog.csdn.net/yc1203968305/article/details/80243840