Inference tricks for large language models

This article explores a series of inference optimization techniques for large language models, covering methods such as KV caching, speculative decoding, sparsity, and quantization, and explains how to apply them effectively. It is worth reading for anyone who wants to make Transformer inference faster or more efficient.

 

The author of this article is machine learning researcher Finbarr Timbers, a former DeepMind engineer. (This article was compiled and published by OneFlow. Please contact OneFlow for authorization before reprinting. Original text: https://www.artfintel.com/p/transformer-inference-tricks)

 

Author | Finbarr Timbers

Compiled by OneFlow

Translation | Yang Ting, Wan Zilin

 

1

Key-value (KV) cache

 

Currently, key-value (KV) caching is the most common (and most important) decoder optimization. In a decoder model, the keys and values computed for the prompt are identical on every decoding iteration. Furthermore, once a token has been processed, its keys and values stay the same on every subsequent iteration. So you can compute the KV tensors for the prompt once, cache them, and append each new token's KV tensors to the cache as you decode, which saves a lot of computation. In the attention mechanism, instead of multiplying two tensors of shape (batch, context_length, feature_dim), we multiply a query tensor of shape (batch, 1, feature_dim) with a KV tensor of shape (batch, context_length, feature_dim). As a result, the cost of sampling is no longer quadratic in context length, and we get good decoding (sampling) performance for much longer contexts.
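
To make the mechanics concrete, here is a minimal NumPy sketch of a single decode step with a KV cache. The shapes and function names are illustrative assumptions, not any particular library's API:

```python
import numpy as np

def decode_step(q, k_new, v_new, K_cache, V_cache):
    """One decoding step with a KV cache.

    q, k_new, v_new:   (batch, 1, d) query/key/value for the newest token
    K_cache, V_cache:  (batch, t, d) cached keys/values for all previous tokens
    """
    # Append the new token's K/V instead of recomputing them for the whole context.
    K = np.concatenate([K_cache, k_new], axis=1)  # (batch, t + 1, d)
    V = np.concatenate([V_cache, v_new], axis=1)  # (batch, t + 1, d)

    # Attention for a single query token: scores have shape (batch, 1, t + 1).
    scores = q @ K.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)

    out = weights @ V    # (batch, 1, d)
    return out, K, V     # the grown cache is reused on the next step
```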

 

In practice, this adds complexity to your implementation: you are no longer running pure functions but carrying state, and you end up continuing to run inference on sequences in a batch that have already finished (see, e.g., Google's MaxText implementation: https://github.com/google/maxtext).

 

The KV cache requires 2 * n_layers * n_heads * d_head parameters per token. For GPT-3, with n_layers = 96, n_heads = 96, and d_head = 128, that is about 2.4M parameters per token of context. At typical 16-bit precision, each token then requires roughly 5MB; with a 2048-token context window, the KV cache needs about 10GB of HBM. It's expensive, but worth every GB consumed.
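
The arithmetic can be reproduced in a few lines (the GPT-3 configuration is the one quoted above):

```python
# KV cache footprint for a GPT-3-sized model at 16-bit precision (2 bytes/value).
n_layers, n_heads, d_head = 96, 96, 128
bytes_per_value = 2

params_per_token = 2 * n_layers * n_heads * d_head        # keys + values
mb_per_token = params_per_token * bytes_per_value / 1e6   # ~4.7 MB
gb_for_context = mb_per_token * 2048 / 1e3                # ~9.7 GB for 2048 tokens

print(f"{params_per_token / 1e6:.1f}M params/token, "
      f"{mb_per_token:.1f} MB/token, {gb_for_context:.1f} GB total")
```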

 

These memory requirements are one of the big reasons why running large language models on consumer GPUs is so difficult. The most powerful consumer graphics card currently available, the 4090, comes with only 24GB of memory. While its floating-point operations per second (FLOPS) are comparable to enterprise-grade chips, its memory capacity is much lower, making it hard to fit both the weights and the KV cache.

 

2

Speculative decoding

 

Speculative decoding is a technique for when compute is abundant, typically in local inference settings. It exploits a property of modern accelerators: running inference on a batch of data can take roughly the same time as running inference on a single data point. On an A100, for example, you can run inference on up to 160 data points in the same amount of time as a single data point. Many techniques take advantage of this property, such as beam search, MCTS (Monte Carlo Tree Search), and speculative decoding.

 

Speculative decoding uses two models: a small, fast one and a large, slow one. Since the inference latency of modern decoders scales roughly with the number of parameters, the smaller model can run many inference passes in the time it takes the large model to run one.

 

Modern decoder models (such as the GPT series) use autoregressive sampling: to sample a sequence of N tokens, the model runs N inference passes, each of which depends on the result of the previous one.
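
As a baseline, plain autoregressive (greedy) decoding looks roughly like the following; the model interface here is a hypothetical one that maps a (1, seq_len) tensor of token ids to (1, seq_len, vocab) logits:

```python
import torch

@torch.no_grad()
def autoregressive_decode(model, prompt_ids, n_new):
    """Sample n_new tokens greedily: n_new sequential forward passes,
    each depending on the output of the previous one."""
    ids = prompt_ids
    for _ in range(n_new):
        next_id = model(ids)[:, -1].argmax(-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
    return ids
```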

 

In speculative decoding, you run both models in parallel. The fast model runs a batch of inference passes, guessing which tokens the large model will predict, and strings those guesses together. Meanwhile, the large model runs in the background and checks whether its own predictions match the smaller model's guesses. The smaller model can make many guesses in the time it takes the large model to run one pass, but because we have spare compute, the large model can evaluate all of those guesses in parallel. As a result, the only place we pay the sequential cost of generation is in the smaller model.
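
Here is a sketch of one round of this idea, using simple greedy acceptance rather than the rejection-sampling scheme from the original paper, and the same hypothetical model interface as above:

```python
import torch

@torch.no_grad()
def speculative_step(draft_model, large_model, prefix, k=4):
    """Draft k tokens with the small model, verify them with one large-model pass."""
    # 1. The draft model proposes k tokens autoregressively (cheap).
    draft = prefix
    for _ in range(k):
        next_id = draft_model(draft)[:, -1].argmax(-1, keepdim=True)
        draft = torch.cat([draft, next_id], dim=1)

    # 2. The large model scores the prefix plus all k guesses in a single forward pass.
    logits = large_model(draft)
    preds = logits[:, prefix.shape[1] - 1:-1].argmax(-1)   # what the large model would emit
    guesses = draft[:, prefix.shape[1]:]

    # 3. Accept guesses up to the first disagreement, then take the large model's token.
    matches = (preds == guesses).int()[0]
    n_accept = int(matches.cumprod(0).sum())
    if n_accept == k:                                       # every guess accepted: bonus token
        correction = logits[:, -1].argmax(-1, keepdim=True)
    else:
        correction = preds[:, n_accept:n_accept + 1]
    return torch.cat([prefix, guesses[:, :n_accept], correction], dim=1)
```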

 

 

The main disadvantage of speculative decoding is that it requires a "draft" model that predicts the output of the larger model, and both models have to live in memory on the same machine (or on the same node in a multi-GPU setup). This adds complexity and requires extra work, because you now have to train two models (the original model and the "draft" model). Furthermore, any speedup is limited by how accurately the small model can predict the large one. If the small model could always predict the large model exactly, we could just use the small model! So there is a fundamental limit on how much speculative decoding can help. HuggingFace claims it typically doubles the decoding rate, which is consistent with the 2-3x improvement claimed in the original paper (https://arxiv.org/abs/2211.17192).

 

Recently, lookahead decoding, a technique that attempts to improve on speculative decoding, has emerged (https://lmsys.org/blog/2023-11-21-lookahead-decoding/). It lets the model generate n-grams and then recursively match those n-grams, with no need for a draft model. It builds on a technique called Jacobi decoding, which can be an improvement over greedy decoding. Jacobi decoding works by generating n tokens at each generation step, "guessing" the entire sequence, and then verifying the new guess against the previous one; where the two match, the guess is accepted. This reduces latency, and in the worst case it simply falls back to greedy decoding.

 

Lookahead decoding improves on this further by keeping the n-grams generated during decoding and reusing them as guesses. Given how strongly the text already generated correlates with the text about to be generated, this has the potential to cut latency significantly at very low cost. It's a clever technique, and since it has only just been released, I'm very curious to see how it performs in real-world workloads.

 

 

3

Effective sparsity

 

In the decoder-only Transformer, the core of the model is the attention mechanism, which can be summarized by the attention equation:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

The softmax operation makes all non-maximum values very small.

 

As a result, we multiply the value tensor (V in the attention equation) by a tensor that consists mostly of zeros, so the output of the attention mechanism contains a large number of zeros, up to 97% (https://x.com/YesThisIsLion/status/1647747069086666752?s=20). Similarly, we get a lot of sparsity after every ReLU in the multi-layer perceptron (MLP) blocks.
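
Effective sparsity is easy to measure empirically. A toy sketch follows (random weights rather than a trained model, so the number it prints is only indicative; trained transformers tend to show much higher sparsity):

```python
import torch

torch.manual_seed(0)
d_model, d_ff, n_tokens = 512, 2048, 1000

x = torch.randn(n_tokens, d_model)
w1 = torch.randn(d_model, d_ff) / d_model ** 0.5
h = torch.relu(x @ w1)                     # first half of a toy MLP block

sparsity = (h == 0).float().mean().item()  # fraction of activations ReLU zeroed out
print(f"fraction of zero activations after ReLU: {sparsity:.1%}")
```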

 

Unfortunately, it is hard to actually exploit this today. If the sparsity were in the weights, a lot could be done with structured sparsity (e.g. torch.sparse), but it is not clear how well existing systems can exploit sparsity in the activations.

 

One possible optimization: if an activation is zero, the weights corresponding to that activation don't need to be loaded, and the corresponding computation can be skipped. As far as I know, mainstream tensor frameworks don't support this well, but for custom inference implementations such as llama.cpp, it is relatively easy to implement.
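
A minimal sketch of the idea for a single matrix-vector product (the real benefit in an implementation like llama.cpp would come from never reading the skipped weight rows from memory at all; this only shows the logic):

```python
import numpy as np

def sparsity_aware_matvec(x, W):
    """x: (d_in,) activations, W: (d_in, d_out) weights.
    Rows of W whose activation is exactly zero are never touched."""
    nonzero = np.flatnonzero(x)
    return x[nonzero] @ W[nonzero, :]

rng = np.random.default_rng(0)
x = rng.standard_normal(4096)
x[rng.random(4096) < 0.9] = 0.0            # ~90% effective sparsity
W = rng.standard_normal((4096, 1024)).astype(np.float32)
assert np.allclose(sparsity_aware_matvec(x, W), x @ W, atol=1e-4)
```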

 

This is because the activations are a function of each token, so the effective sparsity is randomly distributed across tokens. As a result, the benefit of this optimization decays exponentially with batch size: if the effective sparsity is X and the batch size is N, the probability that a given activation is zero for every sequence in the batch is X^N. I made a table of this for different values of X and N; the decay is dramatic.
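
The decay is easy to see by computing X^N directly (the values of X and N below are illustrative, not necessarily the author's exact table):

```python
# P(all N activations in the batch are zero) = X ** N
for X in (0.80, 0.90, 0.95, 0.99):
    row = "  ".join(f"N={N}: {X ** N:6.1%}" for N in (1, 2, 4, 8, 16))
    print(f"X={X:.2f} -> {row}")
```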

 

 

Therefore it is very hard to exploit this approach except at batch size 1, and even then it is often more efficient to use speculative decoding. But if you want to run inference locally and really need to minimize latency, this can be a great trick.

 

4

Quantization

 

Quantization is one of the more familiar techniques. I have written about quantization before (https://finbarrtimbers.substack.com/p/efficient-llm-inference), so I won't spend much time on specific methods here. It is difficult to accurately measure the effect of quantization: the models used in the GPTQ paper and similar work are far from SOTA, because the big labs don't release the models they actually use and academia can't match their resources.

 

 

For example, GPTQ reports quantization results for the OPT and BLOOM models, which are far behind the current crop of open-source models, let alone GPT-4.

 

Of course, the big labs don't publish their research here, and most of the case reports I see come from people trying to run smaller models on consumer-grade hardware, which has very limited memory. I think many hobbyists (i.e. researchers outside the big labs) are drawn to quantization by the appeal of running huge models locally. But quantization has no inherent advantage! From first principles, if you have two models using the same total number of bits, they should run at roughly the same tokens/second and reach roughly similar performance. Quantization would only make a big difference if the higher-precision format were using its bits poorly.

 

But the literature disagrees with my intuition. The GPTQ paper mentioned above found minimal performance degradation when quantizing models down to 4-bit precision. I believed this was because worse-performing models are more likely to keep their performance intact through quantization. Take two identical LLM architectures, one trained on 2 trillion tokens and one trained on 500 billion (call them LLM-2T and LLM-500B). I expected the model trained on more tokens to suffer more from quantization, since it should be making fuller use of its weights' precision. We would still expect quantized LLM-2T to outperform LLM-500B, but I thought the drop from LLM-2T to quantized LLM-2T would be significantly larger than the drop from LLM-500B to quantized LLM-500B.

 

Note: although the above argument sounds persuasive, there is actually no literature to support it. Quantization really does seem to be very close to a free lunch.

 

Recent research, such as the paper on k-bit inference scaling laws (https://arxiv.org/abs/2212.09720), ran extensive experiments across a range of LLM architectures to study how the bit allocation affects model performance. They examined the trade-off between a model with N parameters at a given precision and a model with 2N parameters at half that precision. The results are very impressive: performance is almost indistinguishable from the unquantized model (at least at 4 bits and above).

 

Basically, they found that precision can be reduced to 4 bits without losing any performance, and the quantization involves almost no trade-off: you can run a model that is 4x smaller without meaningfully reducing performance. This matters because inference throughput on modern accelerators scales with the number of bits processed (i.e. lower precision buys you more operations per second).

 

So my conclusion is: follow the recommendations of the k-bit inference paper. For production workloads, however, I'd hesitate to go below 8 bits of precision. fp8 is currently the lowest-precision floating-point format natively supported by modern accelerators, and even then the support is limited. I would recommend training and running inference at fp8, and then checking whether the accuracy loss from quantizing further is acceptable for your use case. It's hard for me to recommend lower precision in production while native support from the platforms (e.g. Nvidia and the Torch/JAX teams) is lacking.

 

From what I understand of the literature (which matches my intuition), fp8 is strictly better than int8, but hardware support is limited. If you're in a GPU-rich organization and can use H100s for everything, go with fp8. Otherwise, int8 is fine and considerably easier to work with (PyTorch makes it quite easy, although the APIs are still unstable).

 

As for actually quantizing a model, the PyTorch team has written a post on how to do it (https://pytorch.org/blog/accelerating-generative-ai/) and provides a set of APIs to simplify the process, although they are still unstable. bitsandbytes is another excellent quantization library, though I haven't used it personally.
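
For a feel of the mechanics, here is a minimal hand-rolled symmetric "absmax" int8 round-trip; this is a common simple scheme shown for illustration, not the method used by GPTQ, the PyTorch APIs, or bitsandbytes:

```python
import torch

def absmax_quantize(w, bits=8):
    """Symmetric per-tensor quantization: scale by the largest absolute value."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max() / qmax
    q = torch.clamp((w / scale).round(), -qmax - 1, qmax).to(torch.int8)
    return q, scale

def dequantize(q, scale):
    return q.float() * scale

w = torch.randn(4096, 4096)
q, scale = absmax_quantize(w)
rel_err = ((dequantize(q, scale) - w).abs().mean() / w.abs().mean()).item()
print(f"{w.numel() * 4 / 1e6:.0f} MB fp32 -> {q.numel() / 1e6:.0f} MB int8, "
      f"mean relative error ~{rel_err:.1%}")
```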

 

(Special thanks to @cis_female for discussing the intricacies of sparsity with me, and to @nostalgebraist for correcting an error in the quantization section. I now think the evidence suggests there is only a very small performance trade-off when quantizing down to 4 bits or more.)

 

 

 


