预训练语言模型 | (2) transformer

原文链接

目录

1. 背景

2. transformer流程与技术细节

3. 总结


1. 背景

17年之前,语言模型都是通过rnn,lstm来建模,这样虽然可以学习上下文之间的关系,但是无法并行化,给模型的训练和推理带来了困难,因此论文提出了一种完全基于attention来对语言建模的模型,叫做transformer。transformer摆脱了nlp任务对于rnn,lstm的依赖,使用了self-attention的方式对上下文进行建模,提高了训练和推理的速度,transformer也是后续更强大的nlp预训练模型的基础,因此有必要花很大的篇幅详解一下这个模型。

transformer中的self-attention是从普通的点积attention中演化出来的,演化过程中可以看:遍地开花的 Attention ,你真的懂吗?

 

2. transformer流程与技术细节

<1> Inputs是经过padding的输入数据,大小是[batch size, max seq length]。

<2> 初始化embedding matrix,通过embedding lookup将Inputs映射成token embedding,大小是[batch size, max seq length, embedding size],然后乘以embedding size的开方。

乘以embedding size开方的原因是:猜测是因为embedding matrix的初始化方式是xavier init,这种方式的方差是1/embedding size,因此乘以embedding size的开方使得embedding matrix的方差是1,在这个scale下可能更有利于embedding matrix的收敛。

<3> 通过sin和cos函数创建positional encoding,表示一个token的绝对位置信息,并加入到token embedding中,然后dropout。

inputs/token embedding加入positional encoding原因:因为self-attention是位置无关的,无论句子的顺序是什么样的,通过self-attention计算的token的hidden embedding都是一样的,这显然不符合人类的思维。因此要有一个办法能够在模型中表达出一个token的位置信息,transformer使用了固定的positional encoding来表示token在句子中的绝对位置信息。positional encoding的公式如下:

至于positional encoding为什么能表示位置信息,可以看如何理解Transformer论文中的positional encoding,和三角函数有什么关系?

<4> multi-head attention

<4.1> 输入token embedding,分别通过三个Dense层生成Q,K,V,大小都是[batch size, max seq length, embedding size],然后按第2维(embedding_size维)split成num heads份并按第0维concat(多个注意力头并行计算),生成新的Q,K,V,大小是[num heads*batch size, max seq length, embedding size/num heads],完成multi-head的操作。

<4.2> 将K的第1维和第2维进行转置,然后Q和转置后的K的进行点积(QK^T),结果的大小是[num heads*batch size, max seq length, max seq length]。(包括每一个token查询向量和其他所有token的键向量的内积结果)

<4.3> 将<4.2>的结果除以hidden size的开方(在transformer中,hidden size=embedding size),完成scale的操作。

scale的原因:以数组为例,2个长度是len,均值是0,方差是1的数组点积会生成长度是len,均值是0,方差是len的数组。而方差变大会导致softmax的输入推向正无穷或负无穷,这时的梯度会无限趋近于0,不利于训练的收敛。因此除以len的开方,可以是数组的方差重新回归到1,有利于训练的收敛。

<4.4> 将<4.3>中padding部分token的点积结果置成一个很小的数(-2^32+1),完成mask操作,后续softmax对padding的结果就可以忽略不计了。

<4.5> 将经过mask的结果进行softmax操作。

<4.6> 将softmax的结果和V进行点积,得到attention的结果,大小是[num heads*batch size, max seq length, hidden size/num heads]。

<4.7> 将attention的结果按第0维split成num heads份并按第2维concat,生成multi-head attention的结果,大小是[batch size, max seq length, hidden size]。Figure 2上concat之后还有一个linear的操作,但是代码里并没有。

为什么attention需要multi-head,一个大head行不行?

multi-head相当于把一个大空间划分成多个互斥的小空间,然后在小空间内分别计算attention,虽然单个小空间的attention计算结果没有大空间计算得精确,但是多个小空间并行计算attention然后结果concat有助于网络捕捉到更丰富的信息,类比cnn网络中的channel。

<5> 将token embedding和multi-head attention的结果相加,并进行Layer Normalization(二者维度相同)。

multi-head attention的输入和输出相加的原因:类似于resnet中的残差学习单元(残差连接,shortcut,缓解梯度消失),有ensemble的思想在里面,解决网络退化问题。

<6> 将<5>的结果经过2层Dense(全连接层),其中第1层的activation=relu,第2层activation=None。

为什么multi-head attention后面要加一个ffn?

类比cnn网络中,cnn block和fc交替连接,效果更好。相比于单独的multi-head attention,在后面加一个ffn,可以提高整个block的非线性变换的能力。

<7> 功能和<5>一样。

<8> Outputs是经过padding的输出数据,与Inputs不同的是,Outputs的需要在序列前面加上一个起始符号”<s>”,用来表示序列生成的开始,而Inputs不需要。(batch_size,max seq length)

<9> 功能和<2>一样。

<10> 功能和<3>一样。

<11> 功能和<4>类似,唯一不同的一点在于mask,<11>中的mask不仅将padding部分token的点积结果置成一个很小的数,而且将当前token(查询向量)与之后的token(键向量)的点积结果也置成一个很小的数。

为什么要mask当前时刻的token与后续token的点积结果?

自然语言生成(例如机器翻译,文本摘要)是auto-regressive的,在推理的时候只能依据之前的token生成当前时刻的token,正因为生成当前时刻的token的时候并不知道后续的token长什么样,所以为了保持训练和推理的一致性,训练的时候也不能利用后续的token来生成当前时刻的token。这种方式也符合人类在自然语言生成中的思维方式。

<12> 功能和<5>一样。

<13> 功能和<4>类似,唯一不同的一点在于Q,K,V的输入,<13>的Q的输入来自于Outputs 的token embedding,<13>的K,V来自于<7>的结果(分别经过两个Dense层)。

<14> 功能和<5>一样。

<15> 功能和<6>一样。

<16> 功能和<7>一样,结果的大小是[batch size, max seq length, hidden size]。

<17> 将<16>的结果的后2维和embedding matrix的转置进行点积,生成的结果的大小是[batch size, max seq length, vocab size]。

<18> 将<17>的结果进行softmax操作,生成的结果就表示当前时刻预测的下一个token在vocab上的概率分布。

<19> 计算<18>得到的下一个token在vocab上的概率分布和真实的下一个token的one-hot形式的cross entropy,然后sum非padding的token的cross entropy当作loss,利用adam进行训练。

 

3. 总结

在小模型上self-attention并不比rnn,lstm好。直到大力出奇迹的bert出现,当模型变得越来越大,样本数越来越多的时候,self-attention无论是并行化带来的训练提速,还是在长距离上的建模,都是要比传统的rnn,lstm好很多。transformer现在已经是各种具有代表性的nlp预训练模型的基础,bert系列预训练模型使用了transformer的encoder,gpt系列预训练模型使用了transformer的decoder。在推荐领域,transformer的multi-head attention也应用得很广泛。

 

 

发布了405 篇原创文章 · 获赞 765 · 访问量 14万+

猜你喜欢

转载自blog.csdn.net/sdu_hao/article/details/104189383