seq2seq model: beam search和attention机制理解

1 seq2seq模型结构

1.1 Encoder和Decoder

seq2seq模型也叫encoder-decoder模型,通过对原始输入进行编码表示,生成目标输出。比如:机器翻译(machine transiation),将一种自然语言翻译成另外一种自然语言,还比如图像描述(image caption),将图像进行编码,解码器生成对应的图像描述文字。而一般用RNN结构模型作为编码和解码器,接下来我们以机器翻译为例,做一个详细的解说,假设我们将目标语言英语翻译成法语: how are you ---->comment vas tu

1.1.1 Encoder(编码器)

编码器,我们将英语:how are you用一个RNN结构进行特征提取,首先我们来看一个最原生的编码器的结构如下:
在这里插入图片描述
首先每个输入word从词向量表中找到对应的词向量 w ∈ R d w \in R^d wRd,然后输入到LSTM中,每个word,在一个time step中会得到hidden state的输出,如上图所示,每个word对应一个输出向量表示: [ e 0 , e 1 , e 2 ] [e_0,e_1,e_2] [e0,e1,e2],用序列的最后一个word的hidden state输出表示编码器的表征 e e e(在这里,上图中的 e 2 = e e_2=e e2=e

1.1.2 Decoder(解码器)

从Encoder编码器中,我们得到了能够表征输入序列的向量 e e e,现在我们就可以在Decoder解码器word by word,进行目标序列的生成了。在解码器中,是另外一个LSTM。首先我们来看LSTM在第一个step的输入:由编码器生成的向量 e e e作为hidden state输入,以及目标生成的句的一个特别的开始输入表征向量 w s o s w_{sos} wsos(开始以及结束标识是人为的加进去的标识),通过LSTM得到hidden state输出向量 h 0 ∈ R h h_0 \in R^h h0Rh,再经过一层分类器 g g g(分类的label数量为目标语言的词数量大小),将 R h − > R V R^h -> R^V Rh>RV,再经过softmax获得概率分布值,而在这里我们需要预测的是"comment"这个词,所以label标签为"comment"索引对应的one-hot label表示,然后将预测出来的"comment"向量和上一个step的hidden state h 0 h_0 h0重复上述过程输入到下一个step中。解码器的结构图如下:
在这里插入图片描述
计算结果表示如下:
h 0 = L S T M ( e , w s o s ) h_0 = LSTM(e, w_{sos}) h0=LSTM(e,wsos)
s 0 = g ( h 0 ) s_0 = g(h_0) s0=g(h0)
p 0 = s o f t m a x ( s 0 ) p_0 = softmax(s_0) p0=softmax(s0)
i 0 = a r g m a x ( p 0 ) i_0 = argmax(p_0) i0=argmax(p0)
. . . ... ...

1.2 attention机制

1.2.1 为什么需要attention

attention机制能够让模型在每一步的解码阶段,只需要将注意力集中在输入序列的关键几个词上,而不需整个编码阶段的所有输入信息,这不仅符合人的常识原理,而且对模型来说也会减少梯度消失,学习困难等问题。

1.2.2 attention结构

在上面的Decoder阶段,模型的输入包含两个部分,假设在时间 t t t时刻上,模型的输入是 t − 1 t-1 t1时刻的hidden state输出 h t − 1 h_{t-1} ht1 t − 1 t-1 t1时刻的预测的目标词 w t − 1 w_{t-1} wt1对应的向量(训练的时候,预测的结果已知,只需要把上一时刻正确的预测的词向量输入到模型),attention结构的使用,在解码阶段,输入的时候多加一个context向量 c t c_t ct,公式改动如下:

h t = L S T M ( h t − 1 , [ w i t − 1 , c t ] ) h_t = LSTM(h_{t-1}, [w_{i_{t-1}},c_{t}]) ht=LSTM(ht1,[wit1,ct])
s t = g ( h t ) s_t = g(h_t) st=g(ht)
p t = s o f t m a x ( s t ) p_t = softmax(s_t) pt=softmax(st)
i t = a r g m a x ( p t ) i_t = argmax(p_t) it=argmax(pt)
向量 c t c_t ct就是attention向量(或者称为context向量),在解码的每个step里,都会对应一个不同的attention向量。那么attention向量怎么计算呢?核心思路就是:在decoder当前 t t t时刻,模型需要关注encoder哪些核心词。通过上一时刻decoder的hidden state输出 h t − 1 h_{t-1} ht1与encoder的每个hidden state输出,通过函数方程计算一个权重分值,然后用这个权重分值对应乘以encoder的hidden state输出向量,得到context向量 c t c_t ct,具体计算步骤如下:

  • Step 1: 计算attention权重
    a t ′ = f ( h t − 1 , e t ′ )    f o r   a l l   t ′ a_{t^{'}} = f(h_{t-1}, e_{t^{'}}) \text{ } \text{ } for \text{ }all \text{ } t^{'} at=f(ht1,et)  for all t

  • Step 2: 权重归一化
    a t ′ ^ = s o f t m a x ( a t ′ ) \hat{a_{t^{'}}} = softmax(a_{t^{'}}) at^=softmax(at)

  • Step 3: 计算attention向量 c t c_t ct
    c t = ∑ t ′ = 0 n a ^ t ′ e t ′ c_t = \sum_{t^{'}=0}^{n} \hat{a}_{t^{'}}e_{t^{'}} ct=t=0na^tet

加入attention机制后模型结构图如下所示:
在这里插入图片描述

1.2.3 attention计算方法

那怎么选择函数 f f f计算attention权重呢?一般有如下几种方法:
f ( h t − 1 , e t ′ ) = { h t − 1 T e t ′   dot h t − 1 T W e t ′    g e n e r a l v T t a n h ( W [ h t − 1 , e t ′ ] )    c o n c a t f(h_{t-1}, e_{t^{'}}) = \begin{cases} h^T_{t-1}e_{t^{'}} \text{ } \text{ dot}\\\\ h^T_{t-1}We_{t^{'}} \text{ } \text{ } general\\\\ v^Ttanh(W[h_{t-1}, e_t{^{'}}]) \text{ } \text{ } concat\end{cases} f(ht1,et)=ht1Tet  dotht1TWet  generalvTtanh(W[ht1,et])  concat

2 模型训练

2.1 数据格式

在解码阶段,核心就是预测每个输入预测对应的label,而这个label就是当前输入的下一个词,我们还是以翻译为例,英语翻译成法语:how are you ->comment vas tu
encoder:把输入序列经过LSTM编码器,计算每个序列词的输出向量以此用来计算在解码阶段的每个目标输入的attention向量。
decoder:解码阶段通过每个时间步输入预测下一个词的label,如下是每个输入词对应的预测label:

<sos> comment vas tu
comment vas tu <eos>

如表格的第一列表示,当输入<sos>时候,对应的预测是"comment",label表示就是"comment"在整个词表 V V V对应的one-hot label向量,同理,后面的一一对应。解码过程结构图如下:
在这里插入图片描述

2.2 优化目标函数

首先我们的目标就是最大优化如下概率函数:
P ( y 1 , . . . , y m ) = ∏ i = 1 m p i [ y i ] P(y_1, ..., y_m) = \prod^m_{i=1}p_i[y_i] P(y1,...,ym)=i=1mpi[yi]
其中 p i [ y i ] p_i[y_i] pi[yi]表示在解码的第 i i i个step上,从概率向量 p i p_i pi中抽取的第 y i y_i yi个序列的概率值。目标是训练模型最大化这个序列概率值,由于每个概率值都是小于1的分值,多个累乘有可能出现数值越界等问题,所以进行转换,等价最小化如下函数:
− l o g P ( y 1 , . . . y m ) = − l o g ∏ i = 1 m p i [ y i ] = − ∑ i = 1 n l o g p i [ y i ] -logP(y_1,...y_m) = -log \prod^m_{i=1}p_i[y_i]=-\sum^n_{i=1}logp_i[y_i] logP(y1,...ym)=logi=1mpi[yi]=i=1nlogpi[yi]
在我们的例子中,这个公式等于:
− l o g p 1 [ c o m m e n t ] − l o g p 2 [ v a s ] − l o g p 3 [ t u ] − l o g p 4 [ < e o s > ] -logp_1[comment] - logp_2[vas] - logp_3[tu] - logp_4[<eos>] logp1[comment]logp2[vas]logp3[tu]logp4[<eos>]
从上面展开形式来看,其实就是在最小化目标分布(one-hot vectors)和模型预测的分布( p i p_i pi)两个分布的交叉熵,本质上,在解码阶段也是对每个输入的词进行下一个词预测的分类模型,对每个输入的word进行类别分类,分类的label是整个目标语言的词的数量 V V V的大小。

3 模型预测

在预测解码阶段,我们希望预测最大概率的序列,如果不考虑计算代价,假设词表大小为 V V V,模型预测到结束标识符假设长度为 L L L,则总共有 V L V^L VL种序列,计算复杂度为 O ( V L ) O(V^L) O(VL)。只有全部搜索所有的组合序列,才可以获得全局最优解,但是计算代价太大,在实际应用场景中,实时性要求高的任务,更加不可能这么计算。所以,在解码阶段我们能否有一种代价较小的计算方法?

3.1 greedy decoding

贪心计算,是一种很常见的方法,也就是在每个step里,选取一个最大的概率分值,如下图所示:
在这里插入图片描述
假设解码目标词表为 V V V,解码长度为 L L L,贪心算法每次只需要保留概率最大的词,则上面的复杂度由 O ( V L ) O(V^L) O(VL)变成 O ( V ∗ L ) O(V*L) O(VL),大大降低了计算的复杂度,提高了解码的速度。
然而贪心搜索并不能保证每次都预测准确,如果有一步预测错误,会将错误一直累积下去,导致预测的序列结果误差越来越大。

3.2 beam search

3.2.1 为什么需要beam search

算出所有的可能序列,求最优的序列,计算代价太大。而用greedy search方法,虽然可以大大降低复杂度,但是准确率会降低,而且只要解码有一步出错,误差累积会越来越大。
那有没有一个折中的策略?当然有!在解码阶段的每一个step里,每次保留top K个最优序列,这样既降低了计算开销,而且也提高了预测的准确率,这就是beam search,也是我们为啥选择用beam search的理由:

  • 解码截断降低了时间复杂度
  • 相对greedy search,准确率更高

3.2.2 beam search运行过程

假设如下的一个seq2seq翻译模型:
在这里插入图片描述
如果基于greedy search,在每个step,我们只需要选择概率分值最大的那个词,如下图过程演示:
在这里插入图片描述
如果用beam search,在每个step中保留top K个最大的概率分值,K越大,计算代价也就越大,但相对准确率也会提高。假设我们令K=3,则解码过程如下:

  • 计算step1中预测词表中的每个词,选取概率分值最大的三个word,假设是: “Jane”, “in", “September”
  • 用step1的分值最大的三个词依次输入rnn模型,计算整个词表中结果,选取分值最大的三个词,则此时产生三个序列为:”Jane is",“Jane visits",”Jane visited",如下图过程:
    在这里插入图片描述
  • 按照上述"Jane"的过程,此时对"in"解码得到:“in Africa”,“in visit",”in September"结果,如下图所示:
    在这里插入图片描述
  • 同样,“September"解码得到:“September would”,”September was",“September is",如下图过程所示:

在这里插入图片描述

  • 全部计算完后,总共有9个序列,从中又选取3个解码过程分值最高的3个序列,再依次进行下一个step的计算过程,如下图所示选取的top3个最大概率序列:

在这里插入图片描述
若Beam Width=3,在解码的每个step选取top 3个预测概率分值最大的词。而当Beam Width=1的时候,每个step选取top 1个预测概率分值最大的词,此时也就相当于greedy search。Beam width越大,会获得更好的结果,同时也需要更多的计算资源。

3.2.3 关于beam width 选择

上一节所说,width设置的越大,则计算代价越大,但准确率会更高,可以根据自己的实际任务去做一个效率和准确率的折中。对预测错误的数据,如果是因为beam width调的太小,最终正确的词没有进候选词,那么就需要调大beam width,如果是因为模型本身预测出错,则需要优化模型的预测能力。

3.2.4 与viterbi区别

首先我们先来看beam search的一个特性:

  • 基于贪心算法思想,每次在解码过程中只考虑top K个最优候选结果,而不是整个词表,是一种启发式算法,缩小搜索空间,提高了效率,但不是全局最优解
  • 只在预测test的时候需要用,train的时候知道了结果不需要beam search
  • 适合于搜索空间非常大的任务,比如自然语言生成过程,一般解码的词表都是非常大

而viterbi (维特比) 特性:

  • 基于动态规划思想,当前步骤根据前一步全部预测结果最高值推测当前全部可能结果最高值,属于全局最优解
  • 适用于搜索空间较小的任务,比如CRF序列标注等任务,预测的label数量不大。

4 参考

https://medium.com/hackernoon/beam-search-attention-for-text-summarization-made-easy-tutorial-5-3b7186df7086

猜你喜欢

转载自blog.csdn.net/BGoodHabit/article/details/109144775