[Deep Learning] Recurrent Neural Networks (RNN) and Long Short-Term Memory Networks (LSTM): A Detailed Derivation and Explanation

Preface

Because I will soon be working on some projects involving financial stock prediction, I spent the last two days going through RNN and LSTM. I had skimmed them a few times before, but was always put off by their intricate network structures and formulas. This time I finally sat down and studied both networks properly, derived their backpropagation by hand, and typed out many formulas and drew many figures myself. It took a long time and the result is very detailed; as long as you read it patiently, I believe anyone can understand how they work, especially LSTM. I am writing this article, first, to make it easier for others to see through the underlying principles, and second, so that I can come back and review it myself. The content is my own understanding combined with notes based on the references. If you are passing by, feel free to like, comment, and follow so we can learn from each other. Thanks again to the authors of the referenced articles and videos!

Recurrent Neural Network (RNN)

Introduction

A Recurrent Neural Network (RNN) is a structure that recurs over time. It is widely used in fields such as natural language processing and speech and image processing, and it is a neural network model designed for sequence data. A traditional neural network consists of an input layer, hidden layers, and an output layer; the outputs are shaped by activation functions, and the layers are connected by weights. The ultimate goal of training is to learn a set of weights so that the network handles new data as required. In a convolutional neural network (CNN), for example, the output computed by forward propagation depends only on the current input and does not take inputs at other time steps into account.

The biggest difference between a recurrent neural network and a traditional neural network is that it has a "memory": it can process sequence-based data. For example, to predict the next word of a sentence you generally need the preceding words, because the words in a sentence do not exist independently of their context; we have to take the surrounding context into account. In other words, for the data an RNN processes, the output at the current time step also depends on the outputs at earlier time steps.

Network Structure

Concretely, the memory of an RNN means that the network remembers earlier information and uses it when computing the current output. Hence there are connections between hidden layers across time steps, and the input to the hidden layer includes not only the current value from the input layer but also the hidden-layer output from the previous time step.

[Figure: an RNN unrolled over time, with input weights U, recurrent weights W, and output weights V]

From the figure above we can see that the weight matrix $W$ acts on the neurons of every hidden layer. Here $x_t$ is the input at time $t$, $s_t$ is the hidden-layer output at time $t$ (denoted $h_t$ below), $o_t$ is the output at time $t$, $U$ is the weight matrix of the input layer, and $V$ is the weight matrix of the output layer.

Weight sharing: as the network structure shows, $U$, $V$, $W$ are one set of parameters shared across all time steps. The benefits are that the same parameters can learn to respond appropriately to different inputs, the number of trainable parameters is reduced, and the lengths of the input and output sequences may differ between examples.

Forward Propagation

[Figure: RNN forward propagation]

From the network structure we can see that the main quantities an RNN computes are the hidden-layer output and the output-layer value:
$$\begin{aligned} h_t &= f(Ux_t + Wh_{t-1}) \\ o_t &= g(Vh_t) \end{aligned}$$
In the formulas above, $f$ and $g$ are activation functions: $f$ is usually taken to be $\tanh$, and $g$ is usually taken to be $\rm softmax$.

By repeatedly substituting the recurrence, we get
$$\begin{aligned} o_t &= g(Vh_t) \\ &= g(Vf(Ux_t + Wh_{t-1})) \\ &= g(Vf(Ux_t + Wf(Ux_{t-1} + Wh_{t-2}))) \\ &= g(Vf(Ux_t + Wf(Ux_{t-1} + Wf(Ux_{t-2} + Wh_{t-3})))) \\ &= \cdots \end{aligned}$$
It follows that the output at time $t$ depends on $x_t, x_{t-1}, x_{t-2}, x_{t-3}, \cdots$.

The bias term $b$ is omitted here; in practice it is usually included.
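
As a concrete illustration, here is a minimal NumPy sketch of this forward pass, assuming $f = \tanh$ and $g = \rm softmax$ as above; the function names, shapes, and toy data are my own choices, and the bias terms are omitted as in the text.

```python
import numpy as np

def softmax(z):
    # numerically stable softmax
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def rnn_forward(xs, U, W, V, h0):
    """Run h_t = tanh(U x_t + W h_{t-1}), o_t = softmax(V h_t) over a sequence.

    xs : list of input vectors x_1..x_T
    U  : (hidden, input) input-to-hidden weights
    W  : (hidden, hidden) hidden-to-hidden weights
    V  : (output, hidden) hidden-to-output weights
    h0 : initial hidden state
    """
    h, hs, os_ = h0, [], []
    for x in xs:
        h = np.tanh(U @ x + W @ h)   # the hidden state carries the "memory"
        o = softmax(V @ h)           # output at this time step
        hs.append(h)
        os_.append(o)
    return hs, os_

# toy usage: 5 time steps, 3-dim inputs, 4 hidden units, 2 output classes
rng = np.random.default_rng(0)
xs = [rng.standard_normal(3) for _ in range(5)]
U, W, V = rng.standard_normal((4, 3)), rng.standard_normal((4, 4)), rng.standard_normal((2, 4))
hs, os_ = rnn_forward(xs, U, W, V, h0=np.zeros(4))
```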

Loss Function

The loss for a single time step can be defined according to the task; a widely used choice is the cross-entropy (CE) loss:
$$L_{\rm CE} = -\left(y_t \ln o_t + (1 - y_t)\ln(1 - o_t)\right)$$
The loss over the whole sequence is the sum of the losses at each time step:
$$L = \sum_t L_{\rm CE}$$
where $y_t$ is the ground-truth label at time $t$ and $o_t$ is the model's predicted output at time $t$.
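
A minimal sketch of this summed cross-entropy, assuming scalar binary predictions $o_t$; the helper name and the toy numbers are only for illustration.

```python
import numpy as np

def sequence_bce_loss(ys, os_):
    """Binary cross-entropy summed over all time steps.

    ys  : iterable of ground-truth labels y_t in {0, 1}
    os_ : iterable of predicted probabilities o_t in (0, 1)
    """
    eps = 1e-12  # avoid log(0)
    return sum(-(y * np.log(o + eps) + (1 - y) * np.log(1 - o + eps))
               for y, o in zip(ys, os_))

# toy usage
print(sequence_bce_loss([1, 0, 1], [0.9, 0.2, 0.7]))
```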

Backpropagation

From the forward pass it is clear that only the three weight matrices $U$, $V$, $W$ need to be optimized, so we compute the gradient with respect to each of them.

Parameter Optimization

  • Partial derivative with respect to $V$
  $$\frac{\partial L}{\partial V} = \sum_t \frac{\partial L_t}{\partial o_t}\frac{\partial o_t}{\partial V}$$
  Here the chain rule must also be applied to the composite function (the activation function).

  • Partial derivative with respect to $W$

  Since this parameter involves information from earlier time steps, the derivative is more involved. Assume we are at time $t = 3$ and differentiate with respect to $W$ using data from the earlier time steps:

  $$\begin{aligned} \frac{\partial L_3}{\partial W} &= \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\left(\frac{\partial h_3}{\partial W} + \frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial W} + \frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial W}\right) \\ &= \frac{\partial L_t}{\partial o_t}\frac{\partial o_t}{\partial h_t}\sum_{j=1}^{t}\left[\left(\prod_{i=j+1}^{t}\frac{\partial h_i}{\partial h_{i-1}}\right)\frac{\partial h_j}{\partial W}\right] \end{aligned}$$

  • Partial derivative with respect to $U$

  Under the same assumption as above, we have

  $$\begin{aligned} \frac{\partial L_3}{\partial U} &= \frac{\partial L_3}{\partial o_3}\frac{\partial o_3}{\partial h_3}\left(\frac{\partial h_3}{\partial U} + \frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial U} + \frac{\partial h_3}{\partial h_2}\frac{\partial h_2}{\partial h_1}\frac{\partial h_1}{\partial U}\right) \\ &= \frac{\partial L_t}{\partial o_t}\frac{\partial o_t}{\partial h_t}\sum_{j=1}^{t}\left[\left(\prod_{i=j+1}^{t}\frac{\partial h_i}{\partial h_{i-1}}\right)\frac{\partial h_j}{\partial U}\right] \end{aligned}$$
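
To make the chain rule above concrete, here is a sketch of backpropagation through time for the simple RNN, assuming a softmax output with one-hot targets and multi-class cross-entropy (so the gradient with respect to the logits is $o_t - y_t$). All names and shapes are my own choices, and biases are again omitted.

```python
import numpy as np

def rnn_bptt(xs, ys, U, W, V, h0):
    """Backpropagation through time for h_t = tanh(U x_t + W h_{t-1}), o_t = softmax(V h_t).

    xs : list of input vectors, ys : list of one-hot target vectors.
    Returns the gradients dU, dW, dV accumulated over all time steps.
    """
    # forward pass, storing intermediate states
    hs, os_ = [h0], []
    for x in xs:
        h = np.tanh(U @ x + W @ hs[-1])
        logits = V @ h
        e = np.exp(logits - logits.max())
        os_.append(e / e.sum())
        hs.append(h)

    dU, dW, dV = np.zeros_like(U), np.zeros_like(W), np.zeros_like(V)
    dh_next = np.zeros_like(h0)            # gradient flowing back from step t+1
    for t in reversed(range(len(xs))):
        do = os_[t] - ys[t]                # dL_t / d(logits) for softmax + cross-entropy
        dV += np.outer(do, hs[t + 1])
        dh = V.T @ do + dh_next            # local path plus path through h_{t+1}
        dz = (1.0 - hs[t + 1] ** 2) * dh   # tanh'
        dU += np.outer(dz, xs[t])
        dW += np.outer(dz, hs[t])          # hs[t] is h_{t-1}
        dh_next = W.T @ dz                 # propagate to the previous time step
    return dU, dW, dV

# toy usage with one-hot targets
rng = np.random.default_rng(0)
xs = [rng.standard_normal(3) for _ in range(5)]
ys = [np.eye(2)[rng.integers(2)] for _ in range(5)]
U, W, V = rng.standard_normal((4, 3)), rng.standard_normal((4, 4)), rng.standard_normal((2, 4))
dU, dW, dV = rnn_bptt(xs, ys, U, W, V, h0=np.zeros(4))
```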

Vanishing and Exploding Gradients

During optimization, as time steps accumulate, the gradients of $U$ and $W$ contain products of many factors, and each factor requires differentiating the activation function $f$:
$$\prod_{i=j+1}^{t}\frac{\partial h_i}{\partial h_{i-1}} = \prod_{i=j+1}^{t} f' \cdot W$$
This leads to a repeated product of activation-function derivatives. Since the activation function is usually $\tanh$ or $\rm sigmoid$, whose function and derivative curves are shown below:

[Figure: tanh and its derivative]
[Figure: sigmoid and its derivative]

We can see that whichever activation function is used, the magnitude of its derivative never exceeds $1$, so the repeated product tends to vanish (conversely, if the recurrent weight $W$ is large, the repeated multiplication by $W$ can instead make the gradient explode). There are therefore two remedies: use a better activation function (such as $\rm ReLU$), or change the propagation structure of the network, which is what the later LSTM does.
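
A toy one-dimensional illustration of this effect (entirely my own construction): the backpropagated factor over $k$ steps is a product of terms $\tanh'(z_t)\cdot w$, which shrinks rapidly when $|w| \le 1$.

```python
import numpy as np

# For h_t = tanh(u*x_t + w*h_{t-1}), dh_t/dh_{t-1} = tanh'(z_t) * w.
# Since tanh' <= 1, with |w| <= 1 the product over many steps decays toward 0.
rng = np.random.default_rng(0)
u, w, h = 0.5, 0.9, 0.0
prod = 1.0
for t in range(1, 31):
    z = u * rng.standard_normal() + w * h
    h = np.tanh(z)
    prod *= (1.0 - h ** 2) * w          # one more factor tanh'(z_t) * w
    if t % 10 == 0:
        print(f"t={t:2d}  product of dh_t/dh_(t-1) = {prod:.3e}")
```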

Long Short-Term Memory Network (LSTM)

Network Structure and Differences from the RNN

A long short-term memory (LSTM) network is a special kind of recurrent neural network. Its propagation structure is designed to avoid the long-term dependency problem, which mitigates the vanishing and exploding gradients of the RNN and gives better performance on longer sequences. An LSTM has essentially the same overall structure as an RNN; the only difference is the repeating module used for forward propagation within a layer: the single repeating module contains four network layers that interact in a particular way. The essential difference is that the LSTM uses its memory cell to selectively remember important information and filter out what is unimportant, reducing the memory burden, whereas the RNN tries to remember everything, which burdens the network.

[Figure: the repeating module of an LSTM, containing four interacting layers]

Core Idea

The key to the LSTM is the cell state, $C_t$ in the figure below. The cell state runs along something like a conveyor belt with only a few linear interactions, so information can easily flow along it and be remembered over long spans without being changed. The LSTM regulates the cell state through structures called "gates", which control how signals propagate and thereby add information to or remove information from the cell state. The "gates" are the elementwise multiplication and addition operations in the structure above; an activation function lets information pass selectively.

[Figure: the cell state C_t flowing through the LSTM like a conveyor belt]

Forward Propagation

Forget Gate

[Figure: the forget gate]

The figure above shows the first stage of the LSTM, the "forget" stage. Concretely, $h_{t-1}$ and $x_t$ are used to compute $f_t$ (forget), which serves as the forget gate and discards unimportant information from the cell state $C_{t-1}$: for each number in $C_{t-1}$ it outputs a value between $0$ and $1$, where $1$ means "keep completely" and $0$ means "forget completely".

In the language-model example of predicting the next word from the words seen so far, the cell state might contain the gender of the current subject so that the correct pronoun can be chosen. When a new subject appears, we want to forget the old one.

$f_t$ is computed as
$$f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)$$
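
A minimal sketch of the forget gate, with toy dimensions and randomly initialized weights chosen only to make the snippet runnable.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# toy dimensions and parameters, just to make the snippet runnable
rng = np.random.default_rng(0)
n_x, n_h = 3, 4                               # input size, cell/hidden size
x_t, h_prev = rng.standard_normal(n_x), rng.standard_normal(n_h)
W_xf, W_hf, b_f = rng.standard_normal((n_h, n_x)), rng.standard_normal((n_h, n_h)), np.zeros(n_h)

# forget gate: one value in (0, 1) per component of the cell state C_{t-1}
f_t = sigmoid(W_xf @ x_t + W_hf @ h_prev + b_f)
print(f_t)   # values near 1 keep information, values near 0 discard it
```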

Input Gate

This is the second stage of the LSTM, the "input" stage, which decides what new information to add to the cell state. It has two parts: first, a $\rm sigmoid$ layer decides which new information to add; second, a $\tanh$ layer creates a vector of candidate values $\tilde{C}_t$.

They are computed as
$$\begin{aligned} i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \\ \tilde{C}_t &= \tanh(W_{xC} x_t + W_{hC} h_{t-1} + b_C) \end{aligned}$$
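
A minimal sketch of the input gate and the candidate vector, with the same kind of toy setup as the forget-gate sketch.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# toy setup, analogous to the forget-gate sketch above
rng = np.random.default_rng(0)
n_x, n_h = 3, 4
x_t, h_prev = rng.standard_normal(n_x), rng.standard_normal(n_h)
W_xi, W_hi, b_i = rng.standard_normal((n_h, n_x)), rng.standard_normal((n_h, n_h)), np.zeros(n_h)
W_xC, W_hC, b_C = rng.standard_normal((n_h, n_x)), rng.standard_normal((n_h, n_h)), np.zeros(n_h)

i_t     = sigmoid(W_xi @ x_t + W_hi @ h_prev + b_i)   # which components to update
C_tilde = np.tanh(W_xC @ x_t + W_hC @ h_prev + b_C)   # candidate values to add
```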

State Update

[Figure: updating the cell state]

Now the old cell state $C_{t-1}$ is updated to $C_t$. The previous stages already decided what to forget, what to keep, and what to add; this step actually carries that out.

It is computed as
$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$
f t f_t ft乘以旧的状态 C t − 1 C_{t-1} Ct1来忘记决定忘记的信息,再加上 i t ⊙ C t ~ i_t \odot \tilde{C_t} itCt~,它是新的候选值,根据我们决定更新每个状态的程度进行变化。

Output Gate

[Figure: the output gate]

As shown in the figure above, this stage decides what to output. The output is based on the current cell state $C_t$, but it is a filtered version of it. First, a $\rm sigmoid$ layer decides which parts of the current state should be output; then the cell state $C_t$ is passed through $\tanh$ to squash it to between $-1$ and $1$ and multiplied by $o_t$, so that only the parts we decided on are finally output.

They are computed as
$$\begin{aligned} o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \\ h_t &= o_t \odot \tanh(C_t) \end{aligned}$$
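
A minimal sketch of the output gate and the new hidden state, again with a toy setup whose names and dimensions are my own choices.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# toy setup, analogous to the earlier gate sketches
rng = np.random.default_rng(0)
n_x, n_h = 3, 4
x_t, h_prev = rng.standard_normal(n_x), rng.standard_normal(n_h)
C_t = rng.standard_normal(n_h)                    # pretend this is the updated cell state
W_xo, W_ho, b_o = rng.standard_normal((n_h, n_x)), rng.standard_normal((n_h, n_h)), np.zeros(n_h)

o_t = sigmoid(W_xo @ x_t + W_ho @ h_prev + b_o)   # which parts of the state to expose
h_t = o_t * np.tanh(C_t)                          # squashed cell state, filtered by o_t
```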

Backpropagation

First, list the forward-propagation equations:
$$\begin{cases} f_t &= \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f) \\ i_t &= \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i) \\ \tilde{C}_t &= \tanh(W_{x\tilde{C}} x_t + W_{h\tilde{C}} h_{t-1} + b_{\tilde{C}}) \\ C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\ o_t &= \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o) \\ h_t &= o_t \odot \tanh(C_t) \\ y_t &= W_y h_t + b_y \end{cases}$$
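
Putting the pieces together, here is a minimal NumPy sketch of one full forward step implementing these equations; the parameter dictionary, key names, and toy dimensions are my own choices, not a reference implementation.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, p):
    """One LSTM forward step following the equations listed above.

    p is a dict of parameters: W_xf, W_hf, b_f, W_xi, W_hi, b_i,
    W_xC, W_hC, b_C (the candidate weights), W_xo, W_ho, b_o, W_y, b_y.
    """
    f_t = sigmoid(p["W_xf"] @ x_t + p["W_hf"] @ h_prev + p["b_f"])       # forget gate
    i_t = sigmoid(p["W_xi"] @ x_t + p["W_hi"] @ h_prev + p["b_i"])       # input gate
    C_tilde = np.tanh(p["W_xC"] @ x_t + p["W_hC"] @ h_prev + p["b_C"])   # candidate values
    C_t = f_t * C_prev + i_t * C_tilde                                   # new cell state
    o_t = sigmoid(p["W_xo"] @ x_t + p["W_ho"] @ h_prev + p["b_o"])       # output gate
    h_t = o_t * np.tanh(C_t)                                             # new hidden state
    y_t = p["W_y"] @ h_t + p["b_y"]                                      # output at time t
    return h_t, C_t, y_t

# toy usage
rng = np.random.default_rng(0)
n_x, n_h, n_y = 3, 4, 2
p = {k: rng.standard_normal((n_h, n_x)) for k in ("W_xf", "W_xi", "W_xC", "W_xo")}
p.update({k: rng.standard_normal((n_h, n_h)) for k in ("W_hf", "W_hi", "W_hC", "W_ho")})
p.update({k: np.zeros(n_h) for k in ("b_f", "b_i", "b_C", "b_o")})
p.update(W_y=rng.standard_normal((n_y, n_h)), b_y=np.zeros(n_y))

h, C = np.zeros(n_h), np.zeros(n_h)
for x in [rng.standard_normal(n_x) for _ in range(5)]:
    h, C, y = lstm_step(x, h, C, p)
```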

Parameter Optimization

From the forward pass we can see that the parameters to optimize are $W_{xf}, W_{hf}, W_{xi}, W_{hi}, W_{x\tilde{C}}, W_{h\tilde{C}}, W_{xo}, W_{ho}, W_y$, so we differentiate with respect to them. Following the RNN backpropagation, assume we are at time $t = 3$ and use the data from earlier time steps to differentiate the loss with respect to $W_{xf}$. Because the parameter is shared across time steps, write $W_{xf}^{(k)}$ for its occurrence at time $k$ and sum the three contributions:
$$\frac{\partial L_3}{\partial W_{xf}} = \frac{\partial L_3}{\partial W_{xf}^{(3)}} + \frac{\partial L_3}{\partial W_{xf}^{(2)}} + \frac{\partial L_3}{\partial W_{xf}^{(1)}}$$

$$\frac{\partial L_3}{\partial W_{xf}^{(3)}} = \frac{\partial L_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}\frac{\partial h_3}{\partial C_3}\frac{\partial C_3}{\partial f_3}\frac{\partial f_3}{\partial W_{xf}^{(3)}}$$
$$\begin{aligned}
\frac{\partial L_3}{\partial W_{xf}^{(2)}} &= \frac{\partial L_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}
\left\{\begin{array}{l}
\frac{\partial h_3}{\partial o_3}\frac{\partial o_3}{\partial h_2}\frac{\partial h_2}{\partial C_2} \\
\frac{\partial h_3}{\partial C_3}
\left\{\begin{array}{l}
\color{red}{\frac{\partial C_3}{\partial C_2}} \\
\color{red}{\frac{\partial C_3}{\partial f_3}\frac{\partial f_3}{\partial h_2}\frac{\partial h_2}{\partial C_2}} \\
\color{red}{\frac{\partial C_3}{\partial i_3}\frac{\partial i_3}{\partial h_2}\frac{\partial h_2}{\partial C_2}} \\
\color{red}{\frac{\partial C_3}{\partial \tilde{C}_3}\frac{\partial \tilde{C}_3}{\partial h_2}\frac{\partial h_2}{\partial C_2}}
\end{array}\right\}
\end{array}\right\}
\frac{\partial C_2}{\partial f_2}\frac{\partial f_2}{\partial W_{xf}^{(2)}} \\
&= \frac{\partial L_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}
\left\{\begin{array}{l}
\frac{\partial h_3}{\partial o_3}\frac{\partial o_3}{\partial h_2}\frac{\partial h_2}{\partial C_2} \\
\frac{\partial h_3}{\partial C_3}\color{red}{\frac{\partial C_3}{\partial C_2}}
\end{array}\right\}
\frac{\partial C_2}{\partial f_2}\frac{\partial f_2}{\partial W_{xf}^{(2)}}
\end{aligned}$$

All backpropagation paths from $t_3$ back to $t_1$ (each ending at $f_1$) can be enumerated as follows, 24 in total:
$$t_3 \to t_1 =
\left\{\begin{array}{l}
t_3 \to t_2: h_3 \to
\left\{\begin{array}{l}
o_3 \to h_2 \\
C_3 \to C_2 \\
C_3 \to f_3 \to h_2 \\
C_3 \to i_3 \to h_2 \\
C_3 \to \tilde{C}_3 \to h_2
\end{array}\right\} \\
t_2 \to t_1:
\left\{\begin{array}{l}
C_2 \to
\left\{\begin{array}{l}
f_2 \to h_1 \to C_1 \to f_1 \\
i_2 \to h_1 \to C_1 \to f_1 \\
\tilde{C}_2 \to h_1 \to C_1 \to f_1 \\
C_1 \to f_1
\end{array}\right\} \\
h_2 \to
\left\{\begin{array}{l}
o_2 \to h_1 \to C_1 \to f_1 \\
C_2 \to C_1 \to f_1 \\
C_2 \to f_2 \to h_1 \to C_1 \to f_1 \\
C_2 \to i_2 \to h_1 \to C_1 \to f_1 \\
C_2 \to \tilde{C}_2 \to h_1 \to C_1 \to f_1
\end{array}\right\}
\end{array}\right\}
\end{array}\right\}
\quad \rm t_3 \to t_2 \to t_1: total \ 24 \ paths$$

$$\begin{aligned}
\frac{\partial L_3}{\partial W_{xf}^{(1)}} &= \frac{\partial L_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}
\left\{\begin{array}{l}
t_3 \to t_2
\left\{\begin{array}{l}
\frac{\partial h_3}{\partial o_3}\frac{\partial o_3}{\partial h_2} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial C_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial f_3}\frac{\partial f_3}{\partial h_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial i_3}\frac{\partial i_3}{\partial h_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial \tilde{C}_3}\frac{\partial \tilde{C}_3}{\partial h_2}}}
\end{array}\right\} \\
t_2 \to t_1
\left\{\begin{array}{l}
\left\{\begin{array}{l}
\color{red}{\frac{\partial C_2}{\partial f_2}\frac{\partial f_2}{\partial h_1}\frac{\partial h_1}{\partial C_1}} \\
\color{red}{\frac{\partial C_2}{\partial i_2}\frac{\partial i_2}{\partial h_1}\frac{\partial h_1}{\partial C_1}} \\
\color{red}{\frac{\partial C_2}{\partial \tilde{C}_2}\frac{\partial \tilde{C}_2}{\partial h_1}\frac{\partial h_1}{\partial C_1}} \\
\color{red}{\frac{\partial C_2}{\partial C_1}}
\end{array}\right\} \\
\left\{\begin{array}{l}
\frac{\partial h_2}{\partial o_2}\frac{\partial o_2}{\partial h_1}\frac{\partial h_1}{\partial C_1} \\
{\color{blue}{\frac{\partial h_2}{\partial C_2}}}
\left\{\begin{array}{l}
\color{green}{\frac{\partial C_2}{\partial C_1}} \\
\color{green}{\frac{\partial C_2}{\partial f_2}\frac{\partial f_2}{\partial C_1}} \\
\color{green}{\frac{\partial C_2}{\partial i_2}\frac{\partial i_2}{\partial h_1}\frac{\partial h_1}{\partial C_1}} \\
\color{green}{\frac{\partial C_2}{\partial \tilde{C}_2}\frac{\partial \tilde{C}_2}{\partial h_1}\frac{\partial h_1}{\partial C_1}}
\end{array}\right\}
\end{array}\right\}
\end{array}\right\}
\frac{\partial C_1}{\partial f_1}
\end{array}\right\}
\frac{\partial f_1}{\partial W_{xf}^{(1)}} \\
&= \frac{\partial L_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}
\left\{\begin{array}{l}
t_3 \to t_2
\left\{\begin{array}{l}
\frac{\partial h_3}{\partial o_3}\frac{\partial o_3}{\partial h_2} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial C_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial f_3}\frac{\partial f_3}{\partial h_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial i_3}\frac{\partial i_3}{\partial h_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial \tilde{C}_3}\frac{\partial \tilde{C}_3}{\partial h_2}}}
\end{array}\right\} \\
t_2 \to t_1
\left\{\begin{array}{l}
{\color{red}{\frac{\partial C_2}{\partial C_1}}} \\
\left\{\begin{array}{l}
\frac{\partial h_2}{\partial o_2}\frac{\partial o_2}{\partial h_1}\frac{\partial h_1}{\partial C_1} \\
{\color{blue}{\frac{\partial h_2}{\partial C_2}}}{\color{green}{\frac{\partial C_2}{\partial C_1}}}
\end{array}\right\}
\end{array}\right\}
\frac{\partial C_1}{\partial f_1}
\end{array}\right\}
\frac{\partial f_1}{\partial W_{xf}^{(1)}} \\
&= \frac{\partial L_3}{\partial y_3}\frac{\partial y_3}{\partial h_3}
\left\{\begin{array}{l}
\frac{\partial h_3}{\partial C_3}{\color{purple}{\frac{\partial C_3}{\partial C_2}}}{\color{red}{\frac{\partial C_2}{\partial C_1}}} \\
\left\{\begin{array}{l}
\frac{\partial h_3}{\partial o_3}\frac{\partial o_3}{\partial h_2} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial f_3}\frac{\partial f_3}{\partial h_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial i_3}\frac{\partial i_3}{\partial h_2}}} \\
\frac{\partial h_3}{\partial C_3}{\color{blue}{\frac{\partial C_3}{\partial \tilde{C}_3}\frac{\partial \tilde{C}_3}{\partial h_2}}}
\end{array}\right\}
\frac{\partial h_2}{\partial o_2}\frac{\partial o_2}{\partial h_1}\frac{\partial h_1}{\partial C_1} \\
\frac{\partial h_3}{\partial o_3}\frac{\partial o_3}{\partial h_2}{\color{blue}{\frac{\partial h_2}{\partial C_2}}}{\color{green}{\frac{\partial C_2}{\partial C_1}}}
\end{array}\right\}
\frac{\partial C_1}{\partial f_1}\frac{\partial f_1}{\partial W_{xf}^{(1)}}
\end{aligned}$$

From the derivation above we can see that the final result contains a large number of repeated products of the form
$$\cdots \prod_{t=m}^{n}\frac{\partial C_t}{\partial C_{t-1}} \cdots$$

where
$$\frac{\partial C_t}{\partial C_{t-1}} = \oplus\begin{cases} f_t \quad \left(\text{the direct term, since } C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t\right) \\ \frac{\partial C_t}{\partial f_t}\frac{\partial f_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} \left(= C_{t-1} \cdot W_{hf}\sigma' \cdot o_{t-1}\tanh'\right) \\ \frac{\partial C_t}{\partial i_t}\frac{\partial i_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} \left(= \tilde{C}_t \cdot W_{hi}\sigma' \cdot o_{t-1}\tanh'\right) \\ \frac{\partial C_t}{\partial \tilde{C}_t}\frac{\partial \tilde{C}_t}{\partial h_{t-1}}\frac{\partial h_{t-1}}{\partial C_{t-1}} \left(= i_t \cdot W_{h\tilde{C}}\tanh' \cdot o_{t-1} \cdot \tanh'\right) \end{cases}$$

From this formula we can see that the value of $\frac{\partial C_t}{\partial C_{t-1}}$ can be controlled flexibly through the forget gate and the parameters $W_{hf}, W_{hi}, W_{h\tilde{C}}$. To prevent the gradient from vanishing, the network can keep this value close to $1$, so that the product becomes a product of factors close to $1$ for $t = 1, 2, 3, \ldots$ Even when more time steps are considered and more factors appear, the product does not shrink toward zero. Other factors in the gradient may still cause the gradient to vanish, but being able to control the repeated product of $\frac{\partial C_t}{\partial C_{t-1}}$ greatly alleviates the vanishing-gradient problem.

This can also be explained from the working mechanism of the LSTM: over the stretch from $t = m$ to $t = n$, the information remembered in the cell state stays largely the same, which also makes the gradients at adjacent time steps nearly the same, thereby relieving the vanishing gradient.
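
A toy comparison of the two gradient paths (my own construction, which keeps only the direct $f_t$ term of $\frac{\partial C_t}{\partial C_{t-1}}$ and ignores the other terms): when the forget gates stay close to $1$, the product over many steps remains of order $1$, whereas the vanilla-RNN factor $\tanh'(z_t)\cdot w$ decays quickly.

```python
import numpy as np

rng = np.random.default_rng(0)
T = 30

# LSTM cell-state path: product of forget-gate values, assumed learned to be near 1
forget_gates = rng.uniform(0.95, 1.0, size=T)
lstm_path = np.prod(forget_gates)

# vanilla RNN path: product of tanh'(z_t) * w over the same number of steps
w, h, rnn_path = 0.9, 0.0, 1.0
for _ in range(T):
    h = np.tanh(0.5 * rng.standard_normal() + w * h)
    rnn_path *= (1.0 - h ** 2) * w

print(f"LSTM cell-state path after {T} steps: {lstm_path:.3f}")
print(f"RNN hidden-state path after {T} steps: {rnn_path:.3e}")
```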

References

[1] Bilibili video: [Revisiting the Classics] A plain-language explanation of how the LSTM long short-term memory network alleviates vanishing gradients, with a hands-on derivation of backpropagation

[2] An introduction to LSTM and a derivation of its backpropagation algorithm that everyone can understand (very detailed)

[3] Understanding LSTM Networks

