Retentive Network (RetNet): a successor to Transformer for large language models

Original information

Original title: "Retentive Network: A Successor to Transformer for Large Language Models"

Original citation: Sun Y, Dong L, Huang S, et al. Retentive Network: A Successor to Transformer for Large Language Models[J]. arXiv preprint arXiv:2307.08621, 2023.

Original link: https://arxiv.org/pdf/2307.08621.pdf

0. Summary

        In this work, we propose Retentive Network (RetNet) as a foundation architecture for large language models, simultaneously achieving training parallelism, low-cost inference, and good performance. We theoretically derive the connection between recurrence and attention. We then propose the retention mechanism for sequence modeling, which supports three computation paradigms, namely parallel, recurrent, and chunkwise recurrent. Specifically, the parallel representation allows for training parallelism. The recurrent representation enables low-cost O(1) inference, which improves decoding throughput, latency, and GPU memory without sacrificing performance. The chunkwise recurrent representation facilitates efficient long-sequence modeling with linear complexity, where each chunk is encoded in parallel while the chunks are summarized recurrently. Experimental results on language modeling show that RetNet achieves favorable scaling results, parallel training, low-cost deployment, and efficient inference. These intriguing properties make RetNet a strong successor to Transformer for large language models. Code will be available at https://aka.ms/retnet.

 Figure 1: Compared with Transformer, Retentive Network (RetNet) achieves low-cost inference, training parallelism, and favorable scaling curves in terms of GPU memory, throughput, and latency. Results of inference cost are reported with an input length of 8k. Figure 6 shows more results for different sequence lengths.

1 Introduction

        Transformer [VSP+17] has become the de facto architecture for large language models [BMR+20]; it was originally proposed to overcome the sequential training problem of recurrent models [HS97]. However, Transformer's training parallelism comes at the cost of inefficient inference: each decoding step has O(N) complexity and is memory-bound by the growing key-value cache [Sha19], which makes Transformer unfriendly to deploy. The growing sequence length increases GPU memory consumption as well as latency, and slows down inference speed.

        Considerable effort continues to go into developing next-generation architectures that maintain training parallelism and performance competitive with Transformer while enabling efficient O(1) inference. It is challenging to achieve these goals simultaneously, i.e., the so-called "impossible triangle" shown in Figure 2.

        There have been three main strands of research. First, linearized attention [KVPF20] approximates the standard attention score exp(q · k) with a kernel ϕ(q) · ϕ(k), so that autoregressive inference can be rewritten in a recurrent form. However, its modeling capability and performance are worse than Transformer's, which hinders the method's popularity. The second strand returns to recurrent models for efficient inference, but at the expense of training parallelism. To compensate, element-wise operators [PAA+23] are used for acceleration, at the cost of representation capability and performance. The third strand explores replacing attention with other mechanisms, such as S4 [GGR21] and its variants [DFS+22, PMN+23]. None of the previous work, however, has broken through the "impossible triangle", and there is no clear winner compared with Transformer.

        In this work, we propose Retentive Network (RetNet), which achieves low-cost inference, efficient long-sequence modeling, Transformer-comparable performance, and parallel model training at the same time. Specifically, we introduce a multi-scale retention mechanism to replace multi-head attention, with three computation paradigms: parallel, recurrent, and chunkwise recurrent representations. First, the parallel representation allows us to fully exploit GPU devices for training parallelism. Second, the recurrent representation enables efficient O(1) inference in terms of memory and computation, which significantly reduces deployment cost and latency; it also removes the need for key-value cache tricks, greatly simplifying the implementation. Third, the chunkwise recurrent representation enables efficient long-sequence modeling: each local chunk is encoded in parallel for computation speed, while the chunks are encoded recurrently at the global level to save GPU memory.

        We conduct extensive experiments comparing RetNet with Transformer and its variants. Experimental results on language modeling show that RetNet is consistently competitive in terms of both scaling curves and in-context learning. Moreover, the inference cost of RetNet is independent of sequence length. For a 7B model and an 8k sequence length, RetNet decodes 8.4× faster than Transformer with key-value cache and saves 70% of memory. During training, RetNet saves 25-50% of memory and achieves a 7× speedup over the standard Transformer, and remains advantageous even compared with the highly optimized FlashAttention [DFE+22]. In addition, RetNet's inference latency is insensitive to batch size, allowing enormous throughput. These intriguing properties make RetNet a strong successor to Transformer for large language models.

Figure 2: RetNet makes the "impossible triangle" possible, achieving training parallelism, good performance, and low inference cost simultaneously.

2. Retentive Network

        The Retentive Network (RetNet) is a stack of L identical blocks, with a layout similar to Transformer [VSP+17] (i.e., residual connections and pre-LayerNorm). Each RetNet block contains two modules: a multi-scale retention (MSR) module and a feed-forward network (FFN) module. We introduce the MSR module in the following sections. Given an input sequence $x = x_1 \cdots x_{|x|}$, RetNet encodes the sequence in an autoregressive manner. The input vectors $\{x_i\}_{i=1}^{|x|}$ are first packed into $X^0 = [x_1, \cdots, x_{|x|}] \in \mathbb{R}^{|x| \times d_{\mathrm{model}}}$, where $d_{\mathrm{model}}$ is the hidden dimension. We then compute the contextualized vector representations $X^l = \mathrm{RetNet}_l(X^{l-1})$, $l \in [1, L]$.

2.1. Retention mechanism

        In this section, we introduce the retention mechanism, which has a dual recurrent and parallel form, so that we can train the model in a parallel manner while conducting inference recurrently. Given an input $X \in \mathbb{R}^{|x| \times d_{\mathrm{model}}}$, we project it to a one-dimensional function $v(n) = X_n \cdot w_V$. Consider the sequence modeling problem of mapping $v(n)$ to $o(n)$ through a state $s_n$; for notational simplicity, we write $v_n$ and $o_n$ for $v(n)$ and $o(n)$. We formulate the mapping recurrently, make the projections content-aware, and then diagonalize the state matrix and absorb it into the projections; the resulting formulation is easily parallelizable within training instances. In summary, we start from the recurrent modeling of Equation (1) and derive its parallel formulation in Equation (4), where † denotes the conjugate transpose. Treating the original mapping $v(n) \mapsto o(n)$ as vector-valued, we obtain the retention mechanism as follows.
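        For reference, the formulas referred to above can be reconstructed from the paper's description as follows (equation numbers follow the paper):

$$s_n = A s_{n-1} + K_n^{\intercal} v_n, \qquad o_n = Q_n s_n = \sum_{m=1}^{n} Q_n A^{n-m} K_m^{\intercal} v_m \tag{1}$$

with content-aware projections

$$Q = X W_Q, \qquad K = X W_K. \tag{2}$$

Diagonalizing $A = \Lambda (\gamma e^{i\theta}) \Lambda^{-1}$ and absorbing $\Lambda$ into $W_Q$ and $W_K$ gives

$$o_n = \sum_{m=1}^{n} Q_n (\gamma e^{i\theta})^{n-m} K_m^{\dagger} v_m, \tag{3}$$

and further treating $\gamma$ as a scalar,

$$o_n = \sum_{m=1}^{n} \gamma^{n-m} \left(Q_n e^{in\theta}\right) \left(K_m e^{im\theta}\right)^{\dagger} v_m. \tag{4}$$

The parallel representation of retention then reads

$$Q = (X W_Q) \odot \Theta, \quad K = (X W_K) \odot \bar{\Theta}, \quad V = X W_V,$$
$$\Theta_n = e^{in\theta}, \quad D_{nm} = \begin{cases} \gamma^{n-m}, & n \ge m \\ 0, & n < m \end{cases}, \quad \mathrm{Retention}(X) = \left(Q K^{\intercal} \odot D\right) V. \tag{5}$$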

        where $\bar{\Theta}$ is the complex conjugate of $\Theta$, and $D \in \mathbb{R}^{|x| \times |x|}$ combines the causal mask and the exponential decay along relative distance into one matrix. Similar to self-attention, this parallel representation enables us to train the models efficiently with GPUs.
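        As a concrete illustration, below is a minimal single-head PyTorch sketch of the parallel form. It is an illustrative toy (the xPos-style rotation Θ and the normalization tricks of Section 2.2 are omitted, and the function name `retention_parallel` is ours), not the paper's reference implementation.

```python
import torch

def retention_parallel(q, k, v, gamma):
    """Single-head parallel retention: (Q K^T ⊙ D) V.

    q, k: (seq_len, d_head); v: (seq_len, d_v); gamma: scalar decay in (0, 1).
    The xPos-style rotation Θ is omitted for brevity.
    """
    n = q.shape[0]
    idx = torch.arange(n, dtype=torch.float32)
    # D[a, b] = gamma^(a-b) for a >= b, else 0: causal mask fused with exponential decay
    exponent = idx.unsqueeze(1) - idx.unsqueeze(0)                  # (n, n)
    D = (gamma ** exponent.clamp(min=0)) * (exponent >= 0).float()
    return (q @ k.T * D) @ v

# toy usage
q, k, v = torch.randn(8, 16), torch.randn(8, 16), torch.randn(8, 32)
out = retention_parallel(q, k, v, gamma=0.9)                        # (8, 32)
```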

        The Recurrent Representation of Retention: As shown in Figure 3b, the proposed mechanism can also be written as a recurrent neural network (RNN), which is favorable for inference. For the n-th timestep, the output is obtained recurrently as:
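Reconstructed from the paper, the recurrent form (Equation (6)) is

$$S_n = \gamma S_{n-1} + K_n^{\intercal} V_n, \qquad \mathrm{Retention}(X_n) = Q_n S_n, \qquad n = 1, \dots, |x|, \tag{6}$$

so each decoding step only updates a fixed-size state rather than attending over a growing cache. Continuing the toy sketch above (same simplifications, illustrative names), a single decoding step could look like:

```python
def retention_recurrent_step(q_n, k_n, v_n, state, gamma):
    """One O(1) decoding step for a single head.
    q_n, k_n: (d_head,); v_n: (d_v,); state: (d_head, d_v) holding S_{n-1}."""
    state = gamma * state + torch.outer(k_n, v_n)   # S_n = γ S_{n-1} + K_n^T V_n
    return q_n @ state, state                       # o_n = Q_n S_n
```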

        The Chunkwise Recurrent Representation of Retention: A hybrid form of the parallel and recurrent representations is available to accelerate training, especially for long sequences. We divide the input sequence into chunks. Within each chunk, we follow the parallel representation (Equation (5)) for computation. In contrast, cross-chunk information is passed following the recurrent representation (Equation (6)). Specifically, let B denote the chunk length. We compute the retention output of the i-th chunk via:
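Reconstructed from the paper, the chunkwise rule (Equation (7)) keeps a cross-chunk state $R_i$ and combines an inner-chunk parallel term with a cross-chunk recurrent term (here b indexes positions inside a chunk, starting from 0):

$$R_i = K_{[i]}^{\intercal} \left(V_{[i]} \odot \zeta\right) + \gamma^{B} R_{i-1}, \qquad \zeta_{bj} = \gamma^{B-b-1},$$
$$\mathrm{Retention}(X_{[i]}) = \underbrace{\left(Q_{[i]} K_{[i]}^{\intercal} \odot D\right) V_{[i]}}_{\text{inner-chunk}} + \underbrace{\left(Q_{[i]} R_{i-1}\right) \odot \xi}_{\text{cross-chunk}}, \qquad \xi_{bj} = \gamma^{b+1}. \tag{7}$$

A single-head sketch under the same simplifications as above (no Θ rotation, no score normalization; names are illustrative):

```python
def retention_chunkwise(q, k, v, gamma, chunk_size):
    """Chunkwise recurrent retention for one head.
    q, k: (n, d_head); v: (n, d_v); n is assumed to be a multiple of chunk_size."""
    n, d_v = q.shape[0], v.shape[1]
    B = chunk_size
    pos = torch.arange(B, dtype=torch.float32)
    # inner-chunk decay matrix: the same D as in the parallel form, but of size B x B
    exponent = pos.unsqueeze(1) - pos.unsqueeze(0)
    D = (gamma ** exponent.clamp(min=0)) * (exponent >= 0).float()
    read_decay = gamma ** (pos + 1)        # ξ_b = γ^(b+1): distance of each position from R_{i-1}
    write_decay = gamma ** (B - 1 - pos)   # ζ_b = γ^(B-b-1): how much each position feeds the next state
    state = torch.zeros(q.shape[1], d_v)   # R_0
    outputs = []
    for start in range(0, n, B):
        q_c, k_c, v_c = q[start:start+B], k[start:start+B], v[start:start+B]
        inner = (q_c @ k_c.T * D) @ v_c                      # parallel within the chunk
        cross = (q_c @ state) * read_decay.unsqueeze(1)      # contribution of all previous chunks
        outputs.append(inner + cross)
        state = gamma ** B * state + k_c.T @ (v_c * write_decay.unsqueeze(1))  # R_i
    return torch.cat(outputs, dim=0)
```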

Figure 3: Dual form of RetNet. "GN" is the abbreviation of GroupNorm.

2.2. Gated multi-scale retention

        In each layer, we use $h = d_{\mathrm{model}}/d$ retention heads, where d is the head dimension. The heads use different parameter matrices $W_Q, W_K, W_V \in \mathbb{R}^{d \times d}$. Moreover, multi-scale retention (MSR) assigns a different γ to each head. For simplicity, we set γ identical among different layers and keep them fixed. In addition, we add a swish gate [HG16, RZL17] to increase the non-linearity of the retention layers. Formally, given an input X, we define the layer as:
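Reconstructed for reference, the layer definition (Equation (8) in the paper) is:

$$\gamma = 1 - 2^{-5-\mathrm{arange}(0,\,h)} \in \mathbb{R}^{h},$$
$$\mathrm{head}_i = \mathrm{Retention}(X, \gamma_i),$$
$$Y = \mathrm{GroupNorm}_h\!\left(\mathrm{Concat}(\mathrm{head}_1, \cdots, \mathrm{head}_h)\right),$$
$$\mathrm{MSR}(X) = \left(\mathrm{swish}(X W_G) \odot Y\right) W_O. \tag{8}$$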

        where $W_G, W_O \in \mathbb{R}^{d_{\mathrm{model}} \times d_{\mathrm{model}}}$ are learnable parameters, and GroupNorm [WH18] normalizes the output of each head, following SubLN proposed in [SPP+19]. Note that the heads use multiple γ scales, which results in different variance statistics, so we normalize the head outputs separately. The pseudocode of retention is summarized in Figure 4.
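        Putting the pieces together, a compact sketch of the gated multi-scale retention layer might look like the following. It reuses the `retention_parallel` toy from Section 2.1, follows the dimension choices of Section 3.1 (value and gate width 2·d_model), and omits the score normalization; class and buffer names are our own, not the official implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiScaleRetention(nn.Module):
    """Gated multi-scale retention (parallel form only), a simplified sketch."""
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.h, self.dk, self.dv = num_heads, d_model // num_heads, 2 * d_model // num_heads
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, 2 * d_model, bias=False)
        self.wg = nn.Linear(d_model, 2 * d_model, bias=False)
        self.wo = nn.Linear(2 * d_model, d_model, bias=False)
        self.gn = nn.GroupNorm(num_heads, 2 * d_model, affine=False)
        # one decay rate per head: γ = 1 - 2^(-5 - arange(h))
        self.register_buffer("gamma", 1 - 2.0 ** (-5 - torch.arange(num_heads).float()))

    def forward(self, x):                        # x: (seq_len, d_model)
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        heads = []
        for i in range(self.h):                  # each head uses its own decay γ_i
            qi = q[:, i * self.dk:(i + 1) * self.dk]
            ki = k[:, i * self.dk:(i + 1) * self.dk]
            vi = v[:, i * self.dv:(i + 1) * self.dv]
            heads.append(retention_parallel(qi, ki, vi, self.gamma[i].item()))
        y = self.gn(torch.cat(heads, dim=-1))    # GroupNorm over the concatenated heads
        return self.wo(F.silu(self.wg(x)) * y)   # swish gate, then output projection
```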

        Retention score normalization: We exploit the scale-invariant property of GroupNorm to improve the numerical precision of the retention layers. Specifically, multiplying a scalar within GroupNorm does not affect the outputs or the backward gradients, i.e., $\mathrm{GroupNorm}(\alpha * \mathrm{head}_i) = \mathrm{GroupNorm}(\mathrm{head}_i)$. We implement three normalization factors in Equation (5). First, we normalize $QK^{\intercal}$ to $QK^{\intercal}/\sqrt{d}$. Second, we replace D with $\tilde{D}_{nm} = D_{nm} / \sqrt{\sum_{i=1}^{n} D_{ni}}$. Third, let R denote the retention scores $R = QK^{\intercal} \odot D$; we normalize it as $\tilde{R}_{nm} = R_{nm} / \max(|\sum_{i=1}^{n} R_{ni}|, 1)$. The retention output then becomes $\mathrm{Retention}(X) = \tilde{R} V$. These tricks do not affect the final results, while stabilizing the numerics of the forward and backward passes thanks to the scale-invariant property.
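        A sketch of how these three factors could be folded into the single-head toy above (again illustrative, not the official kernels):

```python
def retention_parallel_normalized(q, k, v, gamma):
    """Parallel retention with the three scale-invariant normalization tricks."""
    n, d = q.shape
    idx = torch.arange(n, dtype=torch.float32)
    exponent = idx.unsqueeze(1) - idx.unsqueeze(0)
    D = (gamma ** exponent.clamp(min=0)) * (exponent >= 0).float()
    scores = q @ k.T / d ** 0.5                              # 1) QK^T / sqrt(d)
    D = D / D.sum(dim=-1, keepdim=True).sqrt()               # 2) D̃_nm = D_nm / sqrt(Σ_i D_ni)
    R = scores * D
    R = R / R.sum(dim=-1, keepdim=True).abs().clamp(min=1)   # 3) R̃_nm = R_nm / max(|Σ_i R_ni|, 1)
    return R @ v
```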

 Figure 4: Pseudocode for the three computation paradigms of retention.

2.3. Overall architecture of the retention network

        For an L-layer retention network, we stack multi-scale retention (MSR) and feed-forward network (FFN) modules to build the model. Formally, the input sequence $\{x_i\}_{i=1}^{|x|}$ is transformed into vectors by a word embedding layer. We use the packed embeddings $X^0 = [x_1, \cdots, x_{|x|}] \in \mathbb{R}^{|x| \times d_{\mathrm{model}}}$ as the input and compute the model output $X^L$:
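Reconstructed from the description, the per-layer computation is

$$Y^{l} = \mathrm{MSR}\!\left(\mathrm{LN}(X^{l})\right) + X^{l},$$
$$X^{l+1} = \mathrm{FFN}\!\left(\mathrm{LN}(Y^{l})\right) + Y^{l},$$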

        where LN(·) is LayerNorm [BKH16]. The FFN part is computed as $\mathrm{FFN}(X) = \mathrm{gelu}(X W_1) W_2$, where $W_1$ and $W_2$ are parameter matrices.

Training: During training, we use the parallel (Equation (5)) and chunkwise recurrent (Equation (7)) representations. Parallelization within sequences or chunks efficiently exploits GPUs to accelerate computation. More importantly, chunkwise recurrence is especially useful for long-sequence training, which is efficient in terms of both FLOPs and memory consumption.

Inference: The recurrent representation (Equation (6)) is employed during inference, which nicely fits autoregressive decoding. The O(1) complexity reduces memory consumption and inference latency while achieving equivalent results.
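To make the O(1) decoding concrete, the recurrent step sketched in Section 2.1 can be driven by a plain loop, with a fixed-size state replacing the KV cache (toy single-head code under the same simplifications as before):

```python
def decode(qs, ks, vs, gamma):
    """Autoregressive decoding with a fixed-size state instead of a growing KV cache.
    qs, ks: (steps, d_head); vs: (steps, d_v) — projections of the generated tokens."""
    state = torch.zeros(qs.shape[1], vs.shape[1])   # memory does not grow with the step count
    outputs = []
    for q_n, k_n, v_n in zip(qs, ks, vs):
        o_n, state = retention_recurrent_step(q_n, k_n, v_n, state, gamma)
        outputs.append(o_n)
    return torch.stack(outputs)

# sanity check: the recurrent and parallel forms agree (up to floating point; Θ omitted in both)
q, k, v = torch.randn(6, 16), torch.randn(6, 16), torch.randn(6, 32)
assert torch.allclose(decode(q, k, v, 0.9), retention_parallel(q, k, v, 0.9), atol=1e-4)
```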

Table 1: Comparing models from different perspectives. RetNet achieves training parallelization, constant inference cost, linear long sequence memory complexity and good performance.

2.4. Relationship and difference with previous methods

        Table 1 compares RetNet with previous methods from various perspectives. The comparison results echo the "impossible triangle" presented in Figure 2. Moreover, RetNet has linear memory complexity for long sequences thanks to the chunkwise recurrent representation. We also summarize the comparisons with specific methods as follows.

Transformer: The parallel representation of retention shares a similar spirit with Transformer [VSP+17]. The most related Transformer variant is Lex Transformer [SDP+22], which uses xPos as the position embedding. As described by Equation (3), the derivation of retention is consistent with xPos. Compared with attention, retention removes the softmax and enables a recurrent formulation, which significantly benefits inference.

S4: Unlike Equation (2), if $Q_n$ and $K_n$ are content-agnostic, the formulation degenerates to S4 [GGR21], where $O = (QK^{\intercal}, QAK^{\intercal}, \dots, QA^{|x|-1}K^{\intercal}) \cdot V$.

Linear attention: Linear attention variants typically use various kernels $\phi(q_i)\phi(k_j)/\sum_{n=1}^{|x|} \phi(q_i)\phi(k_n)$ to replace the softmax function. However, linear attention struggles to effectively encode position information, rendering the models less performant. Besides, we re-examine sequence modeling from scratch, rather than aiming to approximate softmax.

AFT/RWKV Attention Free Transformer (AFT) simplifies dot product attention to element-wise operations and moves softmax to the key vector. RWKV replaces AFT’s positional embeddings with exponential decay and runs the model recursively during training and inference. In contrast, retention retains high-dimensional states to encode sequence information, which helps improve the expressive power and performance of the model.

xPos/RoPE: Compared with relative position embedding methods proposed for Transformer, Equation (3) presents a formulation similar to xPos [SDP+22] and RoPE [SLP+21].

Sub-LayerNorm: As shown in Equation (8), the retention layer uses Sub-LayerNorm [WMH+22] to normalize the outputs. Because multi-scale modeling leads to different variances across heads, we replace the original LayerNorm with GroupNorm.

3. Experiment

        We conducted language modeling experiments to evaluate RetNet. We evaluate the proposed architecture using various benchmarks, including language modeling performance as well as zero/few-shot learning for downstream tasks. Additionally, for training and inference, we compared speed, memory consumption, and latency.

Table 2: Model size and learning hyperparameters in language modeling experiments.

Figure 5: Perplexity decreases as model size increases. We empirically confirm that RetNet often outperforms Transformer when the model size is larger than 2B.

3.1. Settings

Parameter allocation: For fair comparison, we re-allocate the parameters in MSR and FFN. Here we use d to denote $d_{\mathrm{model}}$ for simplicity. In Transformer, there are about $4d^2$ parameters in self-attention, i.e., $W_Q, W_K, W_V, W_O \in \mathbb{R}^{d \times d}$, and $8d^2$ parameters in the FFN, whose intermediate dimension is 4d. In comparison, the retention layer of RetNet has $8d^2$ parameters, i.e., $W_Q, W_K \in \mathbb{R}^{d \times d}$, $W_G, W_V \in \mathbb{R}^{d \times 2d}$, and $W_O \in \mathbb{R}^{2d \times d}$. Note that the head dimension of V is twice that of Q and K; the expanded dimension is projected back to d by $W_O$. To keep the parameter count the same as Transformer, the FFN intermediate dimension of RetNet is 2d. Meanwhile, we set the head dimension to 256 in the experiments, i.e., 256 for queries and keys and 512 for values. For fair comparison, we keep γ identical across different model sizes, where $\gamma = 1 - e^{\mathrm{linspace}(\log 1/32,\ \log 1/512,\ h)} \in \mathbb{R}^{h}$, instead of the default values in Equation (8).
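The parameter accounting above can be sanity-checked with a few lines; the numbers below (d = 4096, h = 16) are illustrative and ignore biases, embeddings, and normalization parameters:

```python
import math
import torch

d, h = 4096, 16                          # illustrative hidden size and head count
# Transformer layer: attention 4d^2 (W_Q, W_K, W_V, W_O) + FFN 8d^2 (intermediate dim 4d)
transformer_params = 4 * d ** 2 + 8 * d ** 2
# RetNet layer: retention 8d^2 (W_Q, W_K: d×d; W_G, W_V: d×2d; W_O: 2d×d) + FFN 4d^2 (intermediate dim 2d)
retnet_params = (d * d + d * d + d * 2 * d + d * 2 * d + 2 * d * d) + (d * 2 * d + 2 * d * d)
assert transformer_params == retnet_params == 12 * d ** 2

# decay rates used in the experiments: γ = 1 - exp(linspace(log 1/32, log 1/512, h))
gamma = 1 - torch.exp(torch.linspace(math.log(1 / 32), math.log(1 / 512), h))
print(gamma)                             # from 1 - 1/32 ≈ 0.969 up to 1 - 1/512 ≈ 0.998
```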

Language model training: As shown in Table 2, we train language models of various sizes (i.e., 1.3B, 2.7B, and 6.7B) from scratch. The training corpus is a curated compilation of The Pile [GBB+20], C4 [DMI+21], and The Stack [KLBA+22]. We append a special token to indicate the start of each sequence. The training batch size is 4M tokens with a maximum length of 2048. We train the models with 100B tokens, i.e., 25k steps. We use the AdamW [LH19] optimizer with β1 = 0.9, β2 = 0.98, and weight decay 0.05. The number of warm-up steps is 375, with linear learning-rate decay. For training stability, the parameter initialization follows DeepNet [WMD+22]. The implementation is based on TorchScale [MWH+22]. We train the models with 512 AMD MI200 GPUs.

3.2. Comparison with Transformer

Language modeling: As shown in Figure 5, we report perplexity on the validation set for the Transformer-based and RetNet-based language models. We present the scaling curves with three model sizes, i.e., 1.3B, 2.7B, and 6.7B. RetNet achieves results comparable to Transformer. More importantly, the results suggest that RetNet is favorable in terms of size scaling; empirically, RetNet starts to outperform Transformer when the model size is larger than 2B. Beyond performance, RetNet training is quite stable in our experiments. The results indicate that RetNet is a strong competitor to Transformer for large language models. We also summarize the language modeling results with different context lengths in Appendix B.

Zero-shot and few-shot evaluation on downstream tasks: We also compare the language models on a wide range of downstream tasks. We evaluate zero-shot and 4-shot learning with the 6.7B models. As shown in Table 3, the datasets include HellaSwag (HS) [ZHB+19], BoolQ [CLC+19], COPA [WPN+19], PIQA [BZB+20], Winograd, Winogrande [LDM12], and StoryCloze (SC) [MRL+17]. The accuracy numbers are consistent with the language modeling perplexity presented in Figure 5. RetNet achieves performance comparable to Transformer in the zero-shot and in-context learning settings.

Table 3: Zero-shot and few-shot learning with Transformer and RetNet. The model size is 6.7B.

Table 4: Training cost of Transformer (Trm), Transformer with FlashAttention (Trm+FlashAttn), and RetNet. We report memory consumption and training throughput (words per second; wps).

3.3. Training cost

        As shown in Table 4, we compare the training speed and memory consumption of Transformer and RetNet, where the training sequence length is 8192. We also compare with FlashAttention [DFE+22], which improves speed and reduces GPU memory IO through recomputation and kernel fusion. In contrast, we implement RetNet in vanilla PyTorch code and leave kernel fusion or FlashAttention-like acceleration for future work. We use the chunkwise recurrent representation of retention as described in Equation (7), with a chunk size of 512. We evaluate on eight Nvidia A100-80GB GPUs, because FlashAttention is highly optimized for A100. Tensor parallelism is enabled for the 6.7B and 13B models.

        Experimental results show that during the training process, RetNet saves more memory than Transformer and has higher throughput. Even compared to FlashAttention, RetNet is still competitive in terms of speed and memory cost. Furthermore, by not relying on a specific kernel, efficient training of RetNet on other platforms becomes easier. For example, we train the RetNet model on an AMD MI200 cluster with good throughput. It is worth noting that RetNet has the potential to further reduce costs through advanced implementations such as kernel fusion.

3.4. Inference cost

        As shown in Figure 6, we compare the memory consumption, throughput, and latency of Transformer and RetNet during inference. Transformer reuses the KV cache of previously decoded tokens, while RetNet uses the recurrent representation described in Equation (6). We evaluate the 6.7B model on an A100-80GB GPU in our experiments. Figure 6 shows that RetNet outperforms Transformer in terms of inference cost.

Memory consumption: As shown in Figure 6a, the memory consumption of Transformer increases linearly due to the KV cache. In contrast, the memory consumption of RetNet remains consistent even for long sequences, so much less GPU memory is required to host RetNet. The additional memory consumption of RetNet is almost negligible (about 3%), while the model weights occupy about 97% of the memory.

Throughput: As shown in Figure 6b, the throughput of Transformer drops as the decoding length increases. In contrast, RetNet achieves higher and length-independent throughput during decoding by leveraging the recurrent representation of retention.

Latency Latency is an important metric in deployment that greatly affects user experience. We report the decoding latency in Figure 6c. Experimental results show that increasing the batch size will increase the latency of the Transformer. Furthermore, as the input length increases, Transformer's latency grows faster. To make latency acceptable, we have to limit the batch size, which hurts the Transformer's overall inference throughput. In contrast, RetNet’s decoding latency is better than Transformer’s and remains essentially the same under different batch sizes and input lengths.

Figure 6: Inference cost of Transformer and RetNet using 6.7B model. RetNet outperforms Transformer in terms of memory consumption, throughput and latency.

(a) GPU memory consumption of Transformer and RetNet.

(b) Throughput of Transformer and RetNet.

(c) Inference latency under different batch sizes.

3.5. Comparison with Transformer variants

        In addition to Transformer, we also compare RetNet with various efficient Transformer variants, including Linear Transformer [KVPF20], RWKV [PAA+23], H3 [DFS+22], and Hyena [PMN+23]. All models have 200M parameters, 16 layers, and a hidden dimension of 1024. For H3, we set the head dimension to 8. For RWKV, we use the TimeMix module to substitute self-attention layers while keeping the FFN layers consistent with the other models for fair comparison. We train the models for 10k steps with a batch size of 0.5M tokens. Most hyperparameters and the training corpus are kept the same as in Section 3.1.

       Table 5 reports perplexity on the in-domain validation set and various out-of-domain corpora, e.g., Project Gutenberg 2019-2022 (PG22) [SDP+22], QMSum [ZYY+21], GovReport [HCP+21], and SummScreen [CCWG21, SSI+22]. Overall, RetNet outperforms the previous methods across the different datasets: it not only achieves better evaluation results on the in-domain corpus, but also obtains lower perplexity on several out-of-domain datasets. This favorable performance makes RetNet a strong successor to Transformer, in addition to the benefit of significant cost reduction (Sections 3.3 and 3.4).

        We also discuss the training and inference efficiency of the compared methods. Let d denote the hidden dimension and n the sequence length. During training, RWKV's token-mixing complexity is O(dn), while Hyena's is O(dn log n) with fast Fourier transform acceleration. These two methods reduce training FLOPs by trading modeling capability for element-wise operators. In comparison, the chunkwise recurrent representation of retention has complexity O(dn(b + h)), where b is the chunk size and h is the head dimension. We usually set b to 512 and h to 256. For large model sizes (i.e., larger d) or long sequence lengths, the additional b + h has a negligible effect. Therefore, RetNet training is quite efficient without sacrificing modeling performance. During inference, among the compared efficient architectures, Hyena has the same complexity as Transformer (i.e., O(n) per step), while the other methods can perform O(1) decoding.

Table 5: Perplexity results for language modeling. RetNet outperforms other architectures on both in-domain evaluation sets and various out-of-domain corpora.

Table 6: Ablation experimental results on in-domain and out-of-domain corpora.

3.6. Ablation studies

        We conduct ablation experiments on various design choices for RetNet and report the language modeling results in Table 6. The evaluation settings and metrics are the same as in Section 3.5.

Architecture: We ablate the swish gate and GroupNorm described in Equation (8). Table 6 shows that both components improve the final performance. First, the gating module is essential for enhancing non-linearity and improving model capability; note that we use the same parameter allocation as Transformer after removing the gate. Second, group normalization in retention balances the variances of the multi-head outputs, which improves training stability and language modeling results.

Multi-scale decay: Equation (8) shows that we use different γ as the decay rates for the retention heads. In the ablation studies, we examine removing the γ decay (i.e., "− γ decay") and applying the same decay rate across heads (i.e., "− multi-scale decay"). Specifically, ablating the γ decay is equivalent to setting γ = 1. In the second setting, we set γ = 127/128 for all heads. Table 6 shows that both the decay mechanism and the use of multiple decay rates improve language modeling performance.

Head dimension: From the recurrent perspective of Equation (1), the head dimension determines the memory capacity of the hidden states. In the ablation study, we reduce the default head dimension from 256 to 64, i.e., 64 for queries and keys and 128 for values. We keep the hidden dimension $d_{\mathrm{model}}$ unchanged, so the number of heads increases. The experimental results in Table 6 show that a larger head dimension achieves better performance.

4. Summary

        In this paper, we propose Retentive Network (RetNet) for sequence modeling, which admits multiple representations, i.e., parallel, recurrent, and chunkwise recurrent. Compared with Transformer, RetNet achieves significantly better inference efficiency (in terms of memory, speed, and latency), favorable training parallelization, and competitive performance. These advantages make RetNet an ideal successor to Transformer for large language models, especially considering the deployment benefits brought by the O(1) inference complexity. In the future, we would like to scale up RetNet in terms of model size [CDH+22] and training steps. Moreover, retention can efficiently work with structured prompting [HSD+22b] by compressing long-term memory. We will also use RetNet as the backbone architecture to train multimodal large language models [HSD+22a, HDW+23, PWD+23]. In addition, we are interested in deploying RetNet models on various edge devices, such as mobile phones.

A. Hyperparameters

B. Results with different context lengths

        As shown in Table 8, we report language modeling results with different context lengths. In order to make the numbers comparable, we use 2048 text chunks as evaluation data and only compute perplexity for the last 128 tokens. Experimental results show that RetNet outperforms Transformer across different context lengths. Besides, RetNet can utilize longer context for better results.

Table 8: Language modeling perplexity of RetNet and Transformer with different context lengths. The results show that RetNet has a consistent advantage across sequence lengths.

