Attention is all you need

Attention is all you need

3 模型结构

    大多数牛掰的序列传导模型都具有encoder-decoder结构. 此处的encoder模块将输入的符号序列\((x_1,x_2,...,x_n)\)映射为连续的表示序列\({\bf z} =(z_1,z_2,...,z_n)\)(特征表示), decorder 模块将\(z\)输出为符号序列\((y_1,y_2,...,y_n)\), 该模块每次只生成一个符号.每一步,模型都是自回归的即生成下一个符号要使用之前生成的符号序列作为输入.

本文的encorder-decorder结构如下图所示:

transformer.png

上图中左侧是encorder stack,右侧是decorder stack

    Encorder: encorder是由6个完全一样的block构成的. 每个block包含2个sub-layer,第一个sub-layer是multihead self-Attention机制, 第二个是position-wise的全连接层. 每个sub-layer都引入残差连接, 后面都要接一个归一化层.这样每个sub-layer的输出就是 \(LayerNorm(x+Sublayer(x))\)

    Decorder: decorder和encorder结构类似, 除了每个block多了一个sub-layer, 多出来的一个sub-layer还要将encorder模块的输出作为输入. 如上图中部所示.剩下的那个Multi-Head Attention layer要添加mask,这是为了确保position-i的预测输出仅仅依赖之前position的预测输出.

下面详细了解一下Attention 机制

    Attention 机制就是将 a query 和 a set of key-value pairs 映射为输出. 输出是这些values的加权平均, 加权系数是the query 和 这些keys计算得到的.
在机器翻译中,我的理解是key-value pairs 建立起比如中文和英文之间的全体映射关系.这个映射关系的好坏直接决定了模型性能. 比如现在从英文翻译中文. key对应着英文,value对应这中文,query自然是英文了.下面的Attention机制是计算query与keys的内积取softmax后作为values的加权系数. 内积表示query与keys的相似度.

比如, I love you. 翻译为我爱你. 有映射关系集(a set of key-values pairs){I->我, love->爱, you->你, ...}(->前面是key,后面是value), 查询向量{I, love, you}.key-values对应的keys集为{I, love, you, ...}, 对应的values集为{我,爱,你, ...}
分别计算查询向量与keys集的相似度, I:{1,0,0, ...}, love:{0,1,0, ...}, you:{0,0,1,...}
查询向量与keys的相似度,也可以认为是与values的相似度,于是加权平均就得到了我,爱,你.

下面再正式梳理一下Attention机制

    Scaled Dot-Product Attention 其输入包括\(d_k\)维的queries和keys,和\(d_v\)维的values.因为queries和keys在同一域中,所以维度是一样的.正如论文最开始就点明的,Attention机制,就是计算query与所有的keys的内积作为加权系数作用于values,为了消除个别内积过大导致梯度过小,内积之后要除以\(\sqrt k\). 众多queries可以一起计算.千言万语都在一个公式中

\(Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt k})V\)

Multi-Head Attention 作者实践发现下面做效果更好. 用不同的线性映射将queries,keys,values映射为新的维度一致的queries,keys,values,然后再经过Scaled Dot-Product Attention 之后得到的 outputs聚合在一起再映射得到最终的输出.公式就不再写了.

这两个Attention如下图所示

attention.png

    再回头看看本文的encoreder-decorder结构, 两种Attention机制输入包含3个Q,K,V.图1中左侧的模块中,Q,K,V都是相同的,来自其上一层; 右侧模块中queries来自上一个decorder layer的输出, 而keys,values则来自encorder模块的输出.

猜你喜欢

转载自www.cnblogs.com/wolfling/p/9419810.html