cs224n学习笔记L6: Language models and RNNs

一、语言模型

1.1 什么是语言模型(LM)

语言模型是指利用当前序列预测下一个可能出现的词。
P ( x ( 1 ) , , x ( T ) ) = P ( x ( 1 ) ) × P ( x ( 2 ) x ( 1 ) ) × × P ( x ( T ) x ( T 1 ) , , x ( 1 ) ) = t = 1 T P ( x ( t ) x ( t 1 ) , , x ( 1 ) ) \begin{aligned} P\left(\boldsymbol{x}^{(1)}, \ldots, \boldsymbol{x}^{(T)}\right) &=P\left(\boldsymbol{x}^{(1)}\right) \times P\left(\boldsymbol{x}^{(2)} | \boldsymbol{x}^{(1)}\right) \times \cdots \times P\left(\boldsymbol{x}^{(T)} | \boldsymbol{x}^{(T-1)}, \ldots, \boldsymbol{x}^{(1)}\right) \\ &=\prod_{t=1}^{T} P\left(\boldsymbol{x}^{(t)} | \boldsymbol{x}^{(t-1)}, \ldots, \boldsymbol{x}^{(1)}\right) \end{aligned}
在这里插入图片描述
输入法的联想功能就是这个原理

1.2 n-gram语言模型

训练一个语言模型的一种简单方法是n-gram, n-gram是指大小为n滑动步长为1的窗口框选出的单词块。包括:

  • unigrams
  • bigrams
  • trigrams
  • 4-grams

1.2.1 n-gram数学原理

假设:第t+1个词 x ( t + 1 ) x^{(t+1)} 由前n-1个词决定:
在这里插入图片描述
由条件概率的定义:
在这里插入图片描述
可以通过直接统计这么gram的频数获得这些n-gram和(n-1)-gram的概率。即:
P count ( x ( t + 1 ) , x ( t ) , , x ( t n + 2 ) ) count ( x ( t ) , , x ( t n + 2 ) ) P\approx \frac{\operatorname{count}\left(\boldsymbol{x}^{(t+1)}, \boldsymbol{x}^{(t)}, \ldots, \boldsymbol{x}^{(t-n+2)}\right)}{\operatorname{count}\left(\boldsymbol{x}^{(t)}, \ldots, \boldsymbol{x}^{(t-n+2)}\right)}

1.2.2 n-gram缺点

以4-grams为例:

  • n-gram模型确实十分简单,但是为了简化问题,丢掉上下文显然是不合理的假设。
  • 稀疏问题: 如果在测试集中出现了一个词,这个词在训练集中没有出现过
    • 导致未登录词的概率为0(解决方案:给每一个词都初始化一个很小的概率,即平滑操作,从而将计数的稀疏矩阵转换为稠密矩阵)
    • 导致分母为零, 且包含这个未登录词的三元组计数都为0(解决办法:只看前两个词/前一个词)。
  • 存储问题:计数需要存储所有出现的n-gram, 复杂度可能为n的指数量级。

此外,当n越大时,理论上预测效果越好,但稀疏问题会变得非常严峻。

1.3 为什么要研究语言建模

  • 语言建模是衡量我们对自然语言理解的一项基本任务(benchmark task)
  • 这是很多语言任务的子模块之一,例如:
    • 预测输入
    • 语音识别
    • 手写字符识别
    • 拼写/语法检查
    • 机器翻译
    • 自动摘要

1.4 LM理解

LM是指一个用来预测下一个词的系统,RNN不是LM,只是一种实现LM的方式。

二、神经网络语言模型

2.1 基于窗口的语言模型

上一节课我们用到了这个模型
在这里插入图片描述
这个模型带来以下提升:

  • 没有稀疏问题
  • 不需要存储所有的grams片段

仍然存在的问题:

  • 固定的窗口太小
  • 增大窗口就需要增大W
  • 由于窗口再大都不为过,那么W也就可能无限增大
  • 权重中每个部分只学到了窗口中某个位置的信息。

因此我们需要一个能够处理任意长度的神经网络结构。

2.2 RNN

注意下面这里的 y ^ \hat y 实际上就是RNN隐藏层的状态向量。
在这里插入图片描述
当t=0时,h0就是隐藏层初始状态,t>0时,隐藏层的计算公式如下:
h ( t ) = σ ( W h h ( t 1 ) + W e e ( t ) + b h ) (2.2.1) h^{(t)} = \sigma (W_hh^{(t-1)} + W_ee^{(t)} + \boldsymbol b_h) \tag {2.2.1}
y ^ ( t ) = s o f t m a x ( U h ( t ) + b U ) \hat y^{(t)} = softmax(Uh^{(t)} + \boldsymbol b_U)
这里的 W h W_h 类似马尔科夫链中的转移矩阵。
在这里插入图片描述

2.3 RNN优缺点

优点:

  • 处理任意长度输入序列
  • 计算一个时间步可以利用前面多个时间步的信息
  • 由于对每个时间步使用同一个权重,学习到的信息是对称的

缺点:

  • 并行度低,计算慢
  • 实际使用时只有短距离的信息依赖

2.4 如何训练RNN模型

  1. 对于一个语料库构成的一个文本序列 x ( 1 ) , , x ( T ) x^{(1)}, \cdots, x^{(T)} , 将其喂给RNN,对每个 x ( t ) x^{(t)} , 都输出一个 y ^ ( t ) \hat y^{(t)} (一个n-class维向量)

  2. 损失函数是在每一个 y ^ ( t ) \hat y^{(t)} 处使用一个softmax函数将其转换为概率分布,并用交叉熵计算其损失
    J ( t ) ( θ ) = C E ( y ( t ) , y ^ ( t ) ) = w V y w ( t ) log y ^ w ( t ) = log y ^ x t + 1 ( t ) J ^ { ( t ) } ( \theta ) = C E \left( \boldsymbol { y } ^ { ( t ) } , \hat { \boldsymbol { y } } ^ { ( t ) } \right) = - \sum _ { w \in V } \boldsymbol { y } _ { w } ^ { ( t ) } \log \hat { \boldsymbol { y } } _ { w } ^ { ( t ) } = - \log \hat { \boldsymbol { y } } _ { x t + 1 } ^ { ( t ) }

  3. 将这T个交叉熵平均作为整体损失:
    J ( θ ) = 1 T t = 1 T J ( t ) ( θ ) = 1 T t = 1 T log y ^ x t + 1 ( t ) J ( \theta ) = \frac { 1 } { T } \sum _ { t = 1 } ^ { T } J ^ { ( t ) } ( \theta ) = \frac { 1 } { T } \sum _ { t = 1 } ^ { T } - \log \hat { \boldsymbol { y } } _ { \boldsymbol { x } _ { t + 1 } } ^ { ( t ) }

  4. 老规矩,在整个数据集上面计算一次损失并更新代价太大,因此使用SGD,通常将语料中的一句话作为一个序列,每次更新使用多句话构成一个batch.

2.5 RNN反向传播

2.5.1 基本计算公式及其推导

如何计算 J ( t ) ( θ ) J^{(t)}(\theta) 对重复使用的权重矩阵 W h W_h 的偏导?先给出ppt的答案:
J ( t ) W h = i = 1 t J ( t ) W h ( i ) \frac { \partial J ^ { ( t ) } } { \partial \boldsymbol { W } _ { \boldsymbol { h } } } = \left. \sum _ { i = 1 } ^ { t } \frac { \partial J ^ { ( t ) } } { \partial \boldsymbol { W } _ { \boldsymbol { h } } } \right| _ { ( i ) }
来看看课堂上的推导讲解,这里要知道在每一个样本的计算中, W h = W h ( 1 ) = = W h ( T ) W_h = W_h|_{(1)} = \cdots = W_h|_{(T)} , 因为训练过程中每一个batch计算完才会更新一次。
在这里插入图片描述
这个看着可能还是不太能理解,我们从数学公式上下手也很简单。根据公式2.2.1:
h ( t ) = σ ( W h h ( t 1 ) + W e e ( t ) + b h ) h^{(t)} = \sigma (W_hh^{(t-1)} + W_ee^{(t)} + b_h)
u = W h h ( t 1 ) + W e e ( t ) + b h u = W_hh^{(t-1)} + W_ee^{(t)} + \boldsymbol b_h , 由于 h ( t 1 ) h^{(t-1)} 是W的函数, J ( t ) J^{(t)} 为与 h ( t ) h^{(t)} 相关的交叉熵损失,因此直接考虑
h ( t ) W h = h ( t ) σ ( u ) ( u W + u h ( t 1 ) h ( t 1 ) W ) = h ( t ) W h ( t ) + λ h ( t 1 ) W h \begin{aligned} \frac { \partial h ^ { ( t ) } } { \partial { W } _ { \boldsymbol { h } } } & =\frac { \partial h ^ { ( t ) } } {\sigma(u)} (\frac{\partial u}{\partial W} + \frac{\partial u}{\partial h^{(t-1)}}\frac{\partial h^{(t-1)}}{\partial W}) \\ &= \left. \frac { \partial h ^ { ( t ) } } {\partial W_h}\right|_{(t)}+ \lambda\frac { \partial h ^ { ( t -1) } } {\partial W_h} \end{aligned}
可以看到这个式子能够被递归展开,而且展开的深度为t。要注意这里是对所有time-step权重求偏导,而反向传播时是一步一步的计算,不需要递归展开。

2.5.2 时间序上的反向传播

由于RNN的真实输出为最后一个time-step的隐藏层 h ( T ) h^{(T)} , 因此反向传播要从这里开始,计算最后一个时间步对W的梯度。整个方向传播的结果就相当于2.5.1中的计算公式。
在这里插入图片描述

2.5.3 RNN-LM文本预测及生成

直接预测下一个time-step的隐藏层向量即可。RNN可以模仿任意语言风格,甚至能够精确的学习到语言中的引号匹配等问题。

2.6 RNN其他用法

  1. 序列标注(词性标注、命名实体识别)
    在这里插入图片描述
  2. 情感分类
    在这里插入图片描述
  3. 用作编码器模块(encoder module)
    在这里插入图片描述
  4. 用于语音识别、手写字符识别、自动摘要等文本生成任务
    在这里插入图片描述

三、评价语言模型:困惑度(perplexity)

使用困惑度评价指标,直观上理解该公式,就是要使LM预测给定语料的概率尽可能大。
p e r p l e x i t v = t = 1 T ( 1 P L M ( x ( t + 1 ) x ( t ) , , x ( 1 ) ) ) 1 / T = t = 1 T ( 1 y ^ x t + 1 ( t ) ) 1 / T = exp ( 1 T t = 1 T log y ^ x t + 1 ( t ) ) = exp ( J ( θ ) ) \begin{aligned} perplexitv &= \prod_{t=1}^{T}(\frac{1}{P_{LM}(x^{(t+1)}|x^{(t)},\cdots,x^{(1)})})^{1/T} \\ &= \prod _ { t = 1 } ^ { T } \left( \frac { 1 } { \hat { \boldsymbol { y } } _ { x _ { t + 1 } } ^ { ( t ) } } \right) ^ { 1 / T } \\&= \exp \left( \frac { 1 } { T } \sum _ { t = 1 } ^ { T } - \log \hat { \boldsymbol { y } } _ { \boldsymbol { x } _ { t + 1 } } ^ { ( t ) } \right) \\ &= \exp ( J ( \theta ) ) \end{aligned}
这与我们训练时的目标函数一致,越小的困惑度的模型越好。除此之外,也可以使用word error rate来评价模型。

四、术语笔记

  1. 本节课中的RNN被称作普通RNN(vanilla RNN)
  2. RNN升级版:GRU、LSTM、多层RNN
  3. 课程结束应该能理解“stacked bidirectional LSTM with residual connections and self-attention”这样的术语。
发布了24 篇原创文章 · 获赞 8 · 访问量 7203

猜你喜欢

转载自blog.csdn.net/geek_hch/article/details/104520145
今日推荐