LSTM 详解

这篇文章打算讲一下LSTM,虽然这类文章已经很多了,但以前刚开始看的时候还是一知半解,有一些细节没有搞清楚,我打算借这篇文章好好梳理一下。

前言

在许多讲LSTM的文章中,都会出现下面这个图。
Alt text
在这里插入图片描述
说实话,这个图确实很清晰明了(对于懂的人来说)。在很多文章中我都发现了这样的问题,有的时候,对于已经明白的人,一些很“显然”的问题就被忽略了,但是对于刚入门的人来说,一些基础的问题却要搞很久才能弄明白。所以,我希望在这里能尽可能讲的“慢”一些,把细节部分都讲清楚。

当然了,在看这篇文章之前,大家应该对RNN有一个基础的了解。

LSTM的大体结构

相比于原始的RNN的隐层(hidden state), LSTM增加了一个细胞状态(cell state),我下面把lstm中间一个时刻t的输入输出标出来:在这里插入图片描述

我们可以先把中间那一坨遮起来,看一下LSTM在t时刻的输入与输出,首先,输入有三个: 细胞状态 C t 1 C_{t-1} ,隐层状态 h t 1 h_{t-1} , t t 时刻输入向量 X t X_t ,而输出有两个:细胞状态 C t C_t , 隐层状态 h t h_t 。其中 h t h_t 还作为 t t 时刻的输出。

至于绿色框内部的结构与逻辑,我会在下面详细的讲,不过当前,我们从这个图里,只需要看出个大概就行了:

  1. 细胞状态 C t 1 C_{t-1} 的信息,一直在上面那条线上传递, t t 时刻的隐层状态 h t h_t 与输入 x t x_t 会对 C t C_t 进行适当修改,然后传到下一时刻去。
  2. C t 1 C_{t-1} 会参与 t t 时刻输出 h t h_t 的计算。
  3. 隐层状态 h t 1 h_{t-1} 的信息,通过LSTM的“门”结构,对细胞状态进行修改,并且参与输出的计算。

总的来说呢,细胞状态的信息一直在上面那条线上传递,隐层状态一直在下面那条线上传递,不过它们会有一些交互,在LSTM中,通常被叫做“门”结构。

LSTM的输入输出

LSTM也是RNN的一种,输入基本没什么差别。通常我们需要一个时序的结构喂给LSTM,数据会被分成 t t 个部分,也就是上面图里面的 X t X_t X t X_t 可以看作是一个向量 ,在实际训练的时候,我们会用batch来训练,所以通常它的shape是**(batch_size, input_dim)**。当然我们来看这个结构的时候可以认为batch_size是1,理解和计算之类的也比较简单。

另外还有一点想啰嗦一下, C 0 C_0 h 0 h_0 的值,也就是两个隐层的初始值,一般是用全0初始化。两个隐层的同样是向量的形式,在定义LSTM的时候,会定义隐层大小(hidden size),即 S h a p e ( C t ) = S h a p e ( h t ) = H i d d e n S i z e Shape(C_t) = Shape(h_t) = HiddenSize 。输出的维度与对应输入是一致的。

LSTM的门结构

LSTM的门结构,简单来说,就是被设计出来的一些计算步骤,通过这些计算,来调整输入与两个隐层的值。

这里首先讲一下图里的几个组件吧。在这里插入图片描述
首先是上面这几个黄色的图案,这东西代表一个“神经元”,也就是 w T x + b w^T x + b 的操作。区别在于使用的激活函数不同, σ \sigma 表示sigmoid函数,它的输出是在0到1之间的, t a n h tanh 是双曲正切函数,它的输出在-1到1之间。

在这里插入图片描述
然后是这个粉色的操作,根据图里画的内容,操作略有不同,不过整体的意思就是向量的按元素操作。具体解释起来可能还比较麻烦,在当前的场景下,可以认为是,两个相同维度的向量,对应的元素进行圆圈内部的操作,比如✖️就是两个相同维度对应元素的乘积组成新的向量。

遗忘门在这里插入图片描述

首先说一下 [ h t 1 , x t ] [h_{t-1}, x_t] 这个东西就代表把两个向量连接起来(操作与numpy.concatenate相同)。然后 f t f_t 就是一个网络的输出,看起来还是很简单的。

然而它为什么叫遗忘门呢,下面是我自己的看法,前面也说了 σ \sigma 的输出在0到1之间,这个输出 f t f_t 逐位与 C t 1 C_{t-1} 的元素相乘,我们可以发现,当 f t f_t 的某一位的值为0的时候,这 C t 1 C_{t-1} 对应那一位的信息就被干掉了,而值为(0, 1),对应位的信息就保留了一部分,只有值为1的时候,对应的信息才会完整的保留。因此,这个操作被称之为遗忘门,也算是“实至名归”。

更新门层在这里插入图片描述

这个门有两个部分,一个是 C t ~ \tilde{C_t} ,这个可以看作是新的输入带来的信息, t a n h tanh 这个激活函数讲内容归一化到-1到1。另一个是 i t i_t ,这个东西看起来和遗忘门的结构是一样的,这里可以看作是新的信息保留哪些部分。
在这里插入图片描述
下面的操作就是对 C t C_t 进行更新,这个公式表示什么呢?看左边,就是前面遗忘门给出的 f t f_t ,这个值乘 C t 1 C_{t-1} ,表示过去的信息有选择的遗忘(保留)。右边也是同理,新的信息 C t ~ \tilde{C_t} i t i_t 表示新的信息有选择的遗忘(保留),最后再把这两部分信息加起来,就是新的状态 C t C_t 了。

输出门层

在这里插入图片描述
最后就是lstm的输出了,此时细胞状态 C t C_t 已经被更新了,这里的 o t o_t 还是用了一个sigmoid函数,表示输出哪些内容,而 C t C_t 通过 t a n h tanh 缩放后与 o t o_t 相乘,这就是这一个timestep的输出了。

公式总结

上面说了这些,看起来可能有些复杂,其实就是那几个公式,这里再把公式总结一下:

f t = σ ( W f [ h t 1 , x t ] + b f ) f_t = \sigma(W_f\cdot[h_{t-1}, x_t] + b_f)
i t = σ ( W i [ h t 1 , x t ] + b i ) i_t = \sigma(W_i\cdot[h_{t-1}, x_t] + b_i)
C t ~ = t a n h ( W c [ h t 1 , x t ] + b c ) \tilde{C_t}=tanh(W_c\cdot[h_{t-1}, x_t] + b_c)
C t = f t C t 1 + i t C t ~ C_t=f_t*C_{t-1} + i_t*\tilde{C_t}
o t = σ ( W o [ h t 1 , x t ] + b o ) o_t= \sigma(W_o\cdot[h_{t-1}, x_t] + b_o)
h t = o t tanh ( C t ) h_t= o_t * \tanh(C_t)

参数计算

上面说了lstm的原理与公式,这里想再讲一下参数是怎么计算的。简单来说,就是上面公式的 W W b b 包含的参数数量。

W W 的话,就是输入维度乘输出维度, b b 的参数量就是加上输出维度。

上面的公式中, W W 有四个: W f , W i , W c , W o W_f,W_i,W_c,W_o ,同样, b b 也是四个 b f , b i , b c , b o b_f, b_i, b_c, b_o

我们假设输入 x t x_t 这个向量的维度是 512 512 ,lstm的隐层数是 256 256 ,根据这个来实际计算一下参数量。

首先是输入 [ h t 1 , x t ] [h_{t-1}, x_t] ,这个前面说过,是两个向量连接起来,因此维度相加: 256 + 512 = 768 256 + 512 = 768
因为隐层是 256 256 ,所以输出就是 256 256 维。 W W 的参数量就是 768 × 256 = 196608 768 \times 256 = 196608 , b b 的参数是 256 256
所以最终的参数量就是: ( 768 × 256 + 256 ) × 4 = 787456 (768 \times 256 + 256) \times 4 = 787456

另外,pytorch中的lstm实现稍有不同,其公式如下:
在这里插入图片描述

上面的 g t g_t 其实就是 C t ~ \tilde{C_t} ,其他符号基本是一致的。可以看到,pytorch中, x t x_t h t 1 h_{t-1} 并没有拼接在一起,而是各自做了对应的运算,这其实就是使用了分块矩阵的技巧进行计算,结果理论上是一样的,不过这里有些不同的就是加了两个bias,因此计算偏置的参数需要乘2。

发布了443 篇原创文章 · 获赞 149 · 访问量 55万+

猜你喜欢

转载自blog.csdn.net/qian99/article/details/88628383