self-attention与Transformer补充

self-attention与Transformer

参考
整体介绍
  • 之前有记过Transformer的笔记,但是属于一知半解的情况,现在看了李宏毅老师的课,对self- attention有了更深的理解,作为之前笔记的补充

  • 首先,self-attention在ML中一般是以这样的方式存在

内部结构
  • 而对于输出 b 1 − b 4 b_1 - b_4 b1b4,self-attention会考虑每个输入 a 1 − a 4 a_1 - a_4 a1a4

  • 具体的方式是:寻找 a 1 − a 4 a_1 - a_4 a1a4之间的相关性:

  • 相关性对计算方式如下

    • 任意两个输入 a i a j a_i \quad a_j aiaj,对输出 b i b_i bi会先取得 a i a_i ai的q矩阵和 a j a_j aj的k矩阵( q = W q ∗ a i q=W_q*a_i q=Wqai k = W k ∗ a j k=W_k * a_j k=Wkaj,其中W为训练出来的权重参数(就类似于CNN中Kernel中的参数一样))

      • 如果使用Dot- product方法,则对qk进行点积即可

      • 如果是additive方法,则计算 W ∗ t a n h ( q + k ) W*tanh(q+k) Wtanh(q+k),其中tanh作用是将范围限制在[-1,1]之间

        • 在这 里插入图片描述
  • 根据这个方法我们可以得到( a 1 a_1 a1为例)关于a1的所有相关性的值,再对所有值做一个softmax就可以得到这些值的score

  • 接下来就可以计算b的值: b 1 = ∑ i a 1 , i ′ v i b^1= \sum_{i} a_{1,i}^{'}v^i b1=ia1,ivi其中 v i v^i vi= W v ∗ a i W^v*a_i Wvai

  • 有一点需要注意的是: W q W k W v W^q \quad W^k \quad W^v WqWkWv三个矩阵是所有输入值共享的,一个self-attention结构共享一个

  • 因此,根据上面的式子,我们可以将输入的合并成一个矩阵计算,就可以实现一次计算得到所有的相关性矩阵A( A ′ A^{'} A是做了个softmax),以及输出矩阵O

  • 于是从外部看就有了以下的形式,也就符合这个公式: A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V) = softmax(\frac{QK^T}{\sqrt{d_k}})V Attention(Q,K,V)=softmax(dk QKT)V,公式中的矩阵顺序颠倒,但不影响理解,其中除以 d k \sqrt d_k d k是为了保证方差为1,有利于梯度下降

multi-head self-attention
  • 其实就可以理解为同一套输入进入多套并行的self-attention模块里面去,对最后的结果进行cancat一下即可

    • M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , . . . . , h e a d h ) W o MultiHead(Q,K,V) = Concat(head_1,....,head_h)W^o MultiHead(Q,K,V)=Concat(head1,....,headh)Wo

    • h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) head_i = Attention(QW_{i}^Q,KW_{i}^K,VW_{i}^V) headi=Attention(QWiQ,KWiK,VWiV)

回过头来看Transformer
  • 回过头看看Transformer的结构,会发现豁然开朗

  • image-20210607155256584
encoder部分
  • 首先对于encoder部分:

    • image-20210607161853853
    • 我们对输入编码后加上位置信息编码(Positional Encoding)后就进入Multi- Head Attention(参考上面的理解),之后做了一个Add&NORM的操作,其中Add代表的是残差,参考RenNet中残差的作用,是防止网络退化的;而Norm则是做了这么一个操作(归一化,但不用考了batch,因此比Batch Norm来的简单)

    •  image-20210806105125604
    • 而在往下的Feed Forward实际上可以为FC层(当然别的也可以),然后再来一组残差+归一化,就是整个encoder的结构了

decoder部分
  • image -20210806105720932
  • 对于decoder,首先对输出考虑Masked版本的多头注意力机制,其实就是在计算相关性的时候可以考虑上文(之前的)但不能考虑下文(之后的),然后再与encoder的输出汇合,再做一组多头注意力机制的前推,残差归一化+FC,就完成了decoder的计算

  • 这个跟encoder十分类似

Guess you like

Origin blog.csdn.net/symuamua/article/details/119448154