大语言模型之十一 Transformer后继者Retentive Networks (RetNet)

在《大语言模型之四-LlaMA-2从模型到应用》的LLama-2推理图中可以看到,在输入“你好!”时,是串行进行的,即先输入“你”这个token,然后是“好”,再然后是“!”token,前一个token需要保留前面的k和v矩阵,这就意味着随着输入sequence length的增长,需要的内存也会快速增长,计算量也会快速增长。这也显示了Transformer尽管在模型训练的时候并发(相比RNN)性能好,且模型的效果也好,但是推理的时候效率就比较低。

RetNet特点

微软提出的RetNet在训练并发、模型效果以及推理效率上都取得了不错的效果。下图是其paper中关于模型性能和推理效率和Transformer的对比情况。
在这里插入图片描述
其官方paper宣称,其实验数据显示,在语言建模任务上:
RetNet 可以达到与 Transformer 相当的困惑度(perplexity)
推理速度达8.4倍
推理算法延迟降低90%
内存占用减少70%
具有良好的扩展性
实现以上改进是因为RetNet 在 Transformer 的基础上,使用多尺度保持(Retention)机制替代了标准的自注意力机制。

与标准自注意力机制相比,保持机制有几大特点:

  • 引入位置相关的指数衰减项取代 softmax,简化了计算,同时使前步的信息以衰减的形式保留下来。
  • 引入复数空间表达位置信息,取代绝对或相对位置编码,容易转换为递归形式。
  • 保持机制使用多尺度的衰减率,增加了模型的表达能力,并利用 GroupNorm 的缩放不变性来提高 Retention 层的数值精度。

训练并发

因为Transformer采用了self-attention机制,每一阶段的输出都可以用Q,K,V进行并发处理,这大大提高了GPU利用率,提高了训练效率。RNN网络的好处是推理的效率高(相比开篇提的KV历史是不需要保留的,这使得计算量和内存都极大减少了),内存复杂度低O(1)。
RetNet的巧妙之处在于,训练的时候依然类似Transformer的并发结构(如图3左边),而在推理的时候则可以采用RNN的结构(如图3右边)。即实现了parallel training, recurrent/chunk-wise inference。

RetNet改进点

RetNet相比于Transformer主要有两点改进:
1.引入multi-scale retention替代了multi-head attention;
2.RetNet可以用三种(parallel/recurrent/chunk-wise )方式实现,公式和结果上是相等的,因而可以在训练和推理的时候选择最为高效的方式实现;chunk-wise可以更高效的处理长sequence的序列情况。
在这里插入图片描述

并行训练

在这里插入图片描述
RetNet摒弃了softmax操作,引入了基于D矩阵的Hadamard积(对应元素相乘),然后是GroupNorm操作。Transformer中softmax为输入序列中的每个token提供了相对的注意力权重,有助于模型学习和保留长期依赖关系。之前也有一些研究是舍弃softmax运算,但是会降低模型性能。
具体来说Transformer中的softmax主要实现了两个目标:
1.对不同的时间步长采用不同的方式加权,这有助于模型注意力放在序列中应该感兴趣的部分,这是相比RNN而言最重要的贡献。RetNet论文中的D矩阵实现注意力机制,D矩阵是一个因果注意力矩阵,即序列的当前输入只能看到过去的信息,D-矩阵假设最近的时间步长比过去的时间步长更重要,因此采用了指数衰减权重。因此,softmax足够灵活,可以对不同的步骤做不同的权重估计,而D-矩阵以固定的预定义方式(指数衰减)权衡所有步骤。最终paper里的结果显示RetNet效果是比Transformer好的。
2.引入非线性,当没有softmax的时候, Q K T QK^T QKT就是一种仿射变换(即从一种空间的标识变为另一种空间的表示),再多层的注意力堆叠也依然是一种仿射变换,非线性是通过GroupNorm的方式实现的,为什么是GroupNorm实现非线性,论文中并没有提及,似乎是盲测多种结构在GroupNorm时效果是最好的。

Transformer和RetNet的异同

这里我将博客《大语言模型之四-LlaMA-2从模型到应用》中LlaMA-2的推理流程图展示在这里了,该图的详细说明见博客。
在这里插入图片描述
图中第一步是用 W q , W v , W k W_q, W_v, W_k Wq,Wv,Wk得到 Q , V , 和 K Q,V,和K Q,V,K,即 Q = X W q Q=XW_q Q=XWq K = X W k K=XW_k K=XWk V = X W v V=XW_v V=XWv,然后通过softmax得到Attention score
o = s o f t m a x ( Q K T d k ) V ( 1 ) o= softmax (\frac{Q K^T}{\sqrt{d_k}})V (1) o=softmax(dk QKT)V(1)
由于RetNet在循环网络结构和并行网络结构都可以得到一样的结果,所以论文的作者现在循环结构中实现激励的“保留”模块,然后将“保留”模块向量化,因此得到的第n个序列输入的结果(保留注意力)为:
o n = ∑ m = 1 n ( Q n A n − m K m T ) v m , Q n ∈ R 1 × d ( 2 ) o_n=\sum_{m=1}^n(Q_nA^{n-m}K_m^T)v_m, Q_n \in \mathbb{R}^{1\times d} (2) on=m=1n(QnAnmKmT)vm,QnR1×d(2)
从上面公式1和2可以看到RetNet网络使用pos矩阵 A n − m A^{n-m} Anm替换掉了Transformer中的softmax的Attention部分(非线性部分这里还没替换掉)。Retention score的想法是这样的,首先通过循环网络状态 s ( n ) \mathbf s(n) s(n)做映射 v ( n ) → o ( n ) v(n) \rightarrow o(n) v(n)o(n),即
s n = A s n − 1 + K n T v n , A ∈ R d × d , K n ∈ R 1 × d ( 3 ) \mathbf s_n = A \mathbf s_{n-1} +K_n^Tv_n, A \in \mathbb R^{d\times d}, K_n \in \mathbb R^{1 \times d} (3) sn=Asn1+KnTvn,ARd×d,KnR1×d(3)
然后使用线性变换递归对序列进行编码:
o n = Q n s n = ∑ m = 1 n ( Q n A n − m K m T ) v m , Q n ∈ R 1 × d ( 4 ) o_n=Q_n \mathbf s_n=\sum_{m=1}^n(Q_nA^{n-m}K_m^T)v_m, Q_n \in \mathbb{R}^{1\times d}(4) on=Qnsn=m=1n(QnAnmKmT)vm,QnR1×d(4)
根据原论文的公式3,可以将位置分为共轭的两个部分。
R e t e n t i o n ( x ) = ∑ m = 1 n ( Q n ( γ e i θ ) n ) ( K m ( γ e i θ ) − m ) T v m , γ , θ ∈ R d ( 5 ) Retention(x) = \sum_{m=1}^n(Q_n(\gamma e^{i \theta})^{n})(K_m(\gamma e^{i\theta})^{-m})^Tv_m, \gamma, \theta \in \mathbb R^d (5) Retention(x)=m=1n(Qn(γeiθ)n)(Km(γeiθ)m)Tvm,γ,θRd(5)
其中 Q n ( γ e i θ ) n Q_n(\gamma e^{i \theta})^{n} Qn(γeiθ)n K m ( γ e i θ ) − m K_m(\gamma e^{i\theta})^{-m} Km(γeiθ)m是位置矩阵,可以采用论文所述的xPos,为了简化上述的方程,可以用标量值替代 γ \gamma γ,这样在训练的时候可以采用并发的方式训练:
在这里插入图片描述
这样可以看到在RetNet时候,在得到Q、K以及V的方法和原始的Transformer一样是可以并发进行的, e i n θ e^{in\theta} einθ对应的位置信息逐点相乘即可。最后一步的Retention score使用到的D矩阵也是可以提前计算的,因为它只是一个相对位置嵌入+因果掩码。

在这里插入图片描述
图 3 RetNet的两种实现方式
公式6对应的就是图3左边的实现方式,在推理的时候采用循环网络的架构,即图3中右边的实现方式,公式如下7所示。
S n = γ S n − 1 + K n T V n , R e t e n t i o n ( X n ) = Q n S n , n = 1 , ⋯   , ∣ x ∣ ( 7 ) S_n=\gamma S_{n-1}+K_n^TV_n , Retention(X_n)=Q_nS_n, n=1,\cdots, |x| (7) Sn=γSn1+KnTVn,Retention(Xn)=QnSn,n=1,,x7
其中的Q、K、V以及 γ \gamma γ的作用和意义和上式6是一样的。

RetNet并行计算过程

假设仅有“你好”这两个token输入序列,长度记为N,Embedding dim,D=3,则得到的QKV是矩阵的维度是NxD,假设初始的矩阵为:
Q = [ 1 2 1 3 2 3 ] , K = [ 1 2 3 4 5 6 ] , V = [ 5 4 3 2 1 0 ] Q=\begin{bmatrix} 1 & 2 & 1 \\ 3 & 2 &3 \end{bmatrix},K=\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 &6 \end{bmatrix},V=\begin{bmatrix} 5 & 4 & 3 \\ 2 & 1 &0 \end{bmatrix} Q=[132213],K=[142536],V=[524130]
第一步:计算 Q K T QK^T QKT
Q K T = [ 1 2 1 3 2 3 ] [ 1 4 2 5 3 6 ] = [ 8 20 16 40 ] ( 8 ) QK^T=\begin{bmatrix} 1 & 2 & 1 \\ 3 & 2 &3 \end{bmatrix} \begin{bmatrix} 1 & 4 \\ 2 & 5 \\ 3 & 6 \end{bmatrix}=\begin{bmatrix} 8 & 20 \\ 16 & 40 \end{bmatrix} (8) QKT=[132213] 123456 =[8162040]8
第二步:计算 Q K T QK^T QKT D D D矩阵的Hadamard积(对应元素相乘)
Q K T ⊙ D = [ 8 20 16 40 ] ⊙ [ 1 0 0.25 1 ] = [ 8 4 0 40 ] ( 9 ) QK^T\odot D=\begin{bmatrix} 8 & 20 \\ 16 & 40 \end{bmatrix} \odot \begin{bmatrix} 1 & 0 \\ 0.25 & 1 \end{bmatrix}= \begin{bmatrix} 8 & 4 \\ 0& 40 \end{bmatrix}(9) QKTD=[8162040][10.2501]=[80440]9
其中 D = [ γ 1 − 0 γ 1 − 1 γ 2 − 0 γ 2 − 1 ] D=\begin{bmatrix} \gamma^{1-0} & \gamma^{1-1} \\ \gamma^{2-0} & \gamma^{2-1} \end{bmatrix} D=[γ10γ20γ11γ21],当 γ = 0.5 \gamma=0.5 γ=0.5时可以得到式9。
第三步和V相乘:
( Q K T ⊙ D ) V = [ 8 4 0 40 ] [ 5 4 3 2 1 0 ] = [ 40 32 24 100 56 12 ] ( 10 ) (QK^T \odot D)V = \begin{bmatrix} 8 & 4 \\ 0& 40 \end{bmatrix} \begin{bmatrix} 5 & 4 & 3\\ 2& 1 & 0 \end{bmatrix} = \begin{bmatrix} 40 & 32 & 24\\ 100& 56 & 12 \end{bmatrix} (10) (QKTD)V=[80440][524130]=[4010032562412](10)
这样就得到了两个token输入时的最终上下文Embedding结果。

RetNet循环网络计算过程

图3的右侧所示过程。这里的Q、K、V和上面的并行计算并不是一样的,这里有下标n,这表明了是第n个输入token对应的矩阵,因而是一个1xD维矩阵(上一小节是NxD维),另外一个区别是含有当前token之前时间和位置信息的状态S,当前状态和前一个状态使用指数衰减因子,如公式7所示。
在计算的时候先是KV不再是QK相乘,图3中可以看到,和并行计算一样,初始的矩阵为
Q = [ 1 2 1 3 2 3 ] , K = [ 1 2 3 4 5 6 ] , V = [ 5 4 3 2 1 0 ] Q=\begin{bmatrix} 1 & 2 & 1 \\ 3 & 2 &3 \end{bmatrix},K=\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 &6 \end{bmatrix},V=\begin{bmatrix} 5 & 4 & 3 \\ 2 & 1 &0 \end{bmatrix} Q=[132213],K=[142536],V=[524130]
第一步计算 K 1 T ⊗ V 1 K_1^T\otimes V_1 K1TV1
K 1 T V 1 = [ 1 2 3 ] [ 5 4 3 ] = [ 5 4 3 10 8 6 15 12 9 ] ( 11 ) K_1^T V_1 = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} \begin{bmatrix} 5 & 4 & 3\end{bmatrix} = \begin{bmatrix} 5 & 4 & 3\\ 10& 8 & 6 \\ 15 & 12 &9 \end{bmatrix} (11) K1TV1= 123 [543]= 510154812369 (11)
第二步计算 S 1 S_1 S1
因为在输入“你”token之前并没有其他token输入,因而 S 0 S_0 S0是不存在的,因而先前没有状态叠加到“你”这个token对应的状态的。
S 1 = γ 0 S 0 + K 1 T V 1 = [ 5 4 3 10 8 6 15 12 9 ] ( 12 ) S_1=\gamma^{0}S_0 + K_1^TV_1=\begin{bmatrix} 5 & 4 & 3\\ 10& 8 & 6 \\ 15 & 12 &9 \end{bmatrix} (12) S1=γ0S0+K1TV1= 510154812369 (12)
第三步将Q和 S 1 S_1 S1相乘得到最终的Attention score。
Q 1 ⊗ S 1 = [ 5 4 3 10 8 6 15 12 9 ] ⊗ [ 1 2 1 ] = [ 5 4 3 20 16 12 15 12 9 ] s u m = [ 40 32 24 ] ( 13 ) Q_1 \otimes S_1 = \begin{bmatrix} 5 & 4 & 3\\ 10& 8 & 6 \\ 15 & 12 &9 \end{bmatrix} \otimes \begin{bmatrix} 1 \\ 2 \\ 1 \end{bmatrix} = \begin{bmatrix} 5 & 4 & 3\\ 20& 16 & 12 \\ 15 & 12 &9 \end{bmatrix}_{sum} = \begin{bmatrix} 40 & 32 & 24\end{bmatrix} (13) Q1S1= 510154812369 121 = 52015416123129 sum=[403224](13)
这和公式10最终结果的第一行是一样的。
第四步计算 K 2 T V 2 K_2^TV_2 K2TV2,和第一步类似,得到:
K 2 T V 2 = [ 4 5 6 ] [ 2 1 0 ] = [ 8 4 0 10 5 0 12 6 0 ] ( 14 ) K_2^T V_2 = \begin{bmatrix} 4 \\ 5 \\ 6 \end{bmatrix} \begin{bmatrix} 2 & 1 & 0\end{bmatrix} = \begin{bmatrix} 8 & 4 & 0\\ 10& 5 & 0 \\ 12 & 6 & 0 \end{bmatrix} (14) K2TV2= 456 [210]= 81012456000 (14)
第五步计算 S 2 S_2 S2
S 2 = γ 2 S 1 + K 2 T V 2 = 0. 5 2 [ 5 4 3 20 16 12 15 12 9 ] + [ 8 4 0 10 5 0 12 6 0 ] = [ 9.25 0.25 0.75 12.5 7 1.5 15.5 9 2.25 ] ( 15 ) S_2=\gamma^{2}S_1 + K_2^TV_2=0.5^2\begin{bmatrix} 5 & 4 & 3\\ 20& 16 & 12 \\ 15 & 12 & 9 \end{bmatrix} + \begin{bmatrix} 8 & 4 & 0\\ 10& 5 & 0 \\ 12 & 6 & 0 \end{bmatrix} = \begin{bmatrix} 9.25 & 0.25 & 0.75\\ 12.5& 7 & 1.5 \\ 15.5 & 9 & 2.25 \end{bmatrix} (15) S2=γ2S1+K2TV2=0.52 52015416123129 + 81012456000 = 9.2512.515.50.25790.751.52.25 (15)
第六步计算最终RetNet score。
Q 2 ⊗ S 2 = [ 9.25 0.25 0.75 12.5 7 1.5 15.5 9 2.25 ] ⊗ [ 3 2 3 ] ( 16 ) = [ 27.25 15 2.25 25 14 3 47.25 27 6.75 ] s u m = [ 100 56 12 ] ( 16 ) Q_2 \otimes S_2 = \begin{bmatrix} 9.25 & 0.25 & 0.75\\ 12.5& 7 & 1.5 \\ 15.5 & 9 & 2.25 \end{bmatrix} \otimes \begin{bmatrix} 3 \\ 2 \\ 3 \end{bmatrix}(16) =\begin{bmatrix} 27.25 & 15 & 2.25 \\ 25& 14 & 3 \\ 47.25 & 27 & 6.75 \end{bmatrix}_{sum} = \begin{bmatrix} 100 & 56 & 12\end{bmatrix} (16) Q2S2= 9.2512.515.50.25790.751.52.25 323 (16)= 27.252547.251514272.2536.75 sum=[1005612](16)
这样公式16的结果正好等于公式10,从公式13和16结果看和并行计算的结果是一样的,但是这中间并没有太多的存储空间需求,也没有非线性softmax计算。

猜你喜欢

转载自blog.csdn.net/shichaog/article/details/133049016