RNN的简单理解

本文的内容主要来自于斯坦福大学FeiFei-Li的CS231n课程,Lecture10,在这里做一个简单的总结,有兴趣的同学可以去看一下这个课程,讲的很好。

1. RNN

RNN的用途:

RNN主要用于序列处理,比如机器翻译,这种输入输出序列之间具有高度的相关性,RNN可以model这种关系,总结一下,按照输入输出的类型,RNN可以做以下几个事情:

1

举几个例子:
one-to-one: CNN
one-to-many: Image Caption
many-to-one: MNIST(glimpse输入)字符分类
many-to-many: 机器翻译

RNN的基本单元:

2

后一个状态由前一个状态以及当前输入决定,fw可以取为tanh等函数。
其中,参数W共用,不随时间改变而改变:

3

RNN的优化:

首先,传统的BP需要这样更新:

4

但是这样存在的问题是,如果输入序列太长,那么参数无法很好的进行更新,梯度在这个过程中很容易消失,因此,很多情况下RNN采用这种分段的BP算法:

5

Image Caption:

Image Caption是一个很好的将CV与NLP结合在一起的应用场景,输入一幅图片,输出一段话对当前图片进行描述,目前的state of the art的效果都是基于RNN实现的,一种经典的RNN用于Image Caption的做法是:

6

前面先用一个pre-trained的CNN结构(ResNet,VGG)对输入图像提取特征,用后面的FC层(如果有的话)特征作为RNN的输入,准确来说,是使用这些特征来对RNN进行状态初始化,接下来的输入输出序列都是自然语言,用当前单词预测下一个单词应该输出什么,直到输出一个终止符(句号)为止。

2. LSTM

LSTM早在1997年就首次提出,然而直到现在才被广泛的采用。我们首先来看一下传统的RNN存在什么问题:

7

对于经典的RNN单元,只有一个状态ht存在,每计算一次ht到h(t-1)的梯度,都需要乘上整个参数矩阵W,而且有非线性函数tanh的存在,这样的话,当序列很长的时候,整个梯度的计算量巨大,而且由于W矩阵的连乘,使得梯度很有可能爆炸以及消失,最终的表现就是训练困难。

8

LSTM很好的解决了这个问题:

9

可见相比于传统的RNN结构,LSTM增加了i,f,o,g几个门结构,这使得整个LSTM存在两个隐含状态ht和ct;

10

LSTM增加的门结构,可以很好的控制隐含层状态随着时间变化的信息流动,这在机器翻译问题中的一个直接体现就是,当前单词与前面的单词存在联系,但是可能只有前面的若干个单词直接影响了这个单词,而不是前面所有单词,这样的话LSTM可以选择性的控制前面单词对当前单词的影响程度。
更为重要的一点是,LSTM结构更易于优化,因为它解决了上面提到的RNN存在的梯度流复杂的问题:

11

可以看到状态ct的梯度流并不涉及直接的对于整个参数矩阵W的连乘,使得计算量小了很多,梯度爆炸和消失的可能行降低很多。

12

这就使得整个梯度在LSTM里面流动非常顺畅,易于优化。
这与ResNet的结构非常相似,ResNet也是利用了跳级连接,从而使得梯度流变得流畅,从而使得加深网络结构成为可能。

最后谈一点自己的理解,LSTM能在机器翻译,Image Caption等任务上取得很好的效果,主要得益于输入序列间的前后文关系,以及输入输出是可以稀疏表达的,LSTM隐含层的数量相比于CNN的FC层少很多,但是一个状态下这些参数W只能根据当前的输入更新,因此LSTM不见得比CNN更易于训练。

这篇文章对于LSTM的介绍挺好,并且给出了一个用LSTM做MNIST字符识别的Tensorflow代码,对于大家提高自己对于LSTM的理解很有帮助。

猜你喜欢

转载自blog.csdn.net/zhangboshen/article/details/79468455