Preface
Two factors made this article possible:
- The first factor is that when I led the LLM project team in Changsha on the paper-review GPT project, we ran into many engineering problems (the more LLM projects you do, the more you find that there is no secret left in current models: the technical architecture and direction choices are no secret, and in the end it all comes down to the continuous optimization of engineering details), such as data issues, and, for example, the limited context length of the large model itself
The former has been solved. For details, please see this article: "Source code interpretation and fine-tuning of academic paper GPT: from chatpaper and gpt_academic to the July paper-review GPT", especially its third part
But the latter is more troublesome: the review corpus contains more than 10,000 papers, and each paper is basically more than 10,000 words long. From earlier articles on this blog, we know that the context length of most models does not exceed 8K, and for papers of this length, anything within 8K is not enough:

| Model | Corresponding context length | Paper-review performance |
| --- | --- | --- |
| GPT3.5 | 4K-16K (unified to 16K after Nov 7, 2023; the 16K fine-tuning interface of 3.5 was also opened on Nov 7, 2023) | 16K effect yet to be tested |
| GPT4 | 8K-32K (upgraded to 128K on Nov 7, 2023) | To be tested |
| LLaMA | 2048 | |
| LLaMA2 | 4096 | |
| LLaMA2-long (paper released Sept 27, 2023) | 32K | Effect yet to be tested |
| LongAlpaca-7B/13B/70B (based on LongLoRA technology) | 32K or more | Effect yet to be tested |
| Baichuan-7B/13B, Baichuan2-7B/13B | 4096 | |
| ChatGLM-6B | 2K | |
| ChatGLM2-6B | 8K-32K | 32K effect yet to be determined |
- The second factor is that this article was originally part of the ChatGLM2-6B write-up and was compiled together with the content on the first-generation ChatGLM-6B. One of the more prominent features of ChatGLM2-6B is its support for a 32K context, and ChatGLM2's 32K context is implemented based on FlashAttention technology
Therefore, in order to explain the principles of FlashAttention, FlashAttention2, and related techniques clearly, that earlier article grew longer and longer, so I extracted the FlashAttention-related content into this standalone article
As for LLaMA2-long and LongAlpaca-7B/13B/70B based on LongLoRA technology, they will be explained in the next blog
This article, like other large model-related articles in this blog, pays great attention to readability.
- For example, in order to continuously improve readability, this article will be revised repeatedly in the near future, with careful attention paid to the structure of headings, the wording, and even the typesetting and punctuation. If something cannot be made easy to understand, I would rather not write it.
- If you do not understand some content or some formula in a certain section, please feel free to leave a message in the comments of this article, and I will be sure to revise it promptly so that you can understand it (friendly reminder: this article assumes that you are already familiar with the transformer; if not, it is recommended to first read "Transformer popular notes: From Word2Vec and Seq2Seq, gradually understand GPT and BERT", especially its third part)
Part 1: Transformer’s space-time complexity and standard attention issues
FlashAttention is a new attention algorithm that is IO-aware, fast, and memory-efficient, proposed by Stanford and the State University of New York at Buffalo in 2022. The corresponding paper is "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness", and this is its GitHub address
What kind of problem does it want to solve?
- First of all, the maximum input/output sequence length of large language models such as GPT3, LLaMA, ChatGLM, and BLOOM is only 2048 or 4096. What makes extending to longer sequences difficult? The essential reason is that both the computational complexity and the space complexity of the transformer model are $O(N^2)$, where $N$ is the sequence length
- To address this, FlashAttention proposes an exact attention that accelerates computation, saves GPU memory, and is IO-aware, which effectively alleviates the above problems
The large open-source model LLaMA launched by Meta and the large open-source model Falcon launched by the United Arab Emirates both use Flash Attention to accelerate computation and save GPU memory. Currently, Flash Attention has been integrated into PyTorch 2.0, and open-source frameworks such as Triton and xFormers have also integrated implementations of it.
1.1 Transformer computational complexity——Self-Attention layer and MLP layer
To put it simply, the computational complexity is proportional to the square of the sequence length. Consider a small example: suppose the sizes of the two multiplied matrices are $(N \times d)$ and $(d \times N)$. One way to compute the matrix product is to take the dot product of each row of the first matrix with each column of the second matrix

Because we need to take the dot product of each row of the first matrix with each column of the second matrix, a total of $N^2$ dot products are required. Each dot product requires $d$ multiplications, so the total complexity is $O(N^2 d)$
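To make the counting concrete, here is a minimal pure-Python sketch (not from the original article; the function name is my own) that multiplies an $(N \times d)$ matrix by a $(d \times N)$ matrix the naive row-by-column way and counts scalar multiplications, confirming the $N^2 \cdot d$ cost:

```python
# Sketch: naive row-by-column matrix multiply, counting multiplications
# to confirm the O(N^2 * d) cost discussed above.

def naive_matmul_with_count(A, B):
    """A is N x d, B is d x N; returns (product, number of multiplications)."""
    N, d = len(A), len(A[0])
    C = [[0.0] * N for _ in range(N)]
    mults = 0
    for i in range(N):          # each row of A ...
        for j in range(N):      # ... dotted with each column of B
            s = 0.0
            for k in range(d):  # one dot product costs d multiplications
                s += A[i][k] * B[k][j]
                mults += 1
            C[i][j] = s
    return C, mults

N, d = 8, 4
A = [[float(i + j) for j in range(d)] for i in range(N)]
B = [[float(i * j % 5) for j in range(N)] for i in range(d)]
C, mults = naive_matmul_with_count(A, B)
print(mults)                    # N^2 dot products, d multiplications each
assert mults == N * N * d
```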
To state it precisely: when the input batch size is $b$ and the sequence length is $s$,
the amount of computation of an $l$-layer transformer model is $l \cdot (24bsh^2 + 4bs^2h)$, where $h$ represents the dimension of the word vectors, i.e. the dimension of the hidden layer (the hidden-layer dimension is usually equal to the word-vector dimension)
But how is this result calculated step by step? Next, let’s break down this calculation process in detail.
1.1.1 Computational complexity of Self-Attention layer
First of all, we know that the transformer model consists of $l$ identical layers, and each layer is divided into two parts: a self-attention block and an MLP block

The model parameters of the self-attention layer come in two parts: one part is the weight matrices $W_Q$, $W_K$, $W_V$ and their biases, and the other part is the output weight matrix $W_O$ and its bias. The total parameter count is $4h^2 + 4h$
How is it calculated specifically?
- The first step is to compute $Q$, $K$, $V$:
$$Q = xW_Q,\quad K = xW_K,\quad V = xW_V$$
The input and output shapes of this matrix multiplication are $[b, s, h] \times [h, h] \to [b, s, h]$, and the amount of computation is $3 \times 2bsh^2 = 6bsh^2$
- Computing $QK^T$: with $a$ attention heads of dimension $h/a$ each, the input and output shapes of this part are $[b, a, s, h/a] \times [b, a, h/a, s] \to [b, a, s, s]$
The amount of computation is $2bs^2h$
- Computing the weighting of $V$ by the attention scores, i.e. $\text{score} \cdot V$:
The input and output shapes of this part of the matrix multiplication are $[b, a, s, s] \times [b, a, s, h/a] \to [b, a, s, h/a]$, and the amount of computation is $2bs^2h$
- The linear mapping after attention: the input and output shapes of the matrix multiplication are $[b, s, h] \times [h, h] \to [b, s, h]$
The amount of computation is $2bsh^2$

The final output of the self-attention layer is
$$x_{out} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot V \cdot W_O + x$$
1.1.2 Computational complexity of MLP layer
The MLP block consists of 2 linear layers, and its output is
$$x = f_{gelu}(x_{out} W_1) W_2 + x_{out}$$

How is it calculated?

Generally, the first linear layer maps the dimension from $h$ to $4h$, and the second linear layer maps the dimension from $4h$ back to $h$
- The weight matrix $W_1$ of the first linear layer has shape $[h, 4h]$, which is equivalent to mapping the dimension from $h$ to $4h$. The input and output shapes of the matrix multiplication are $[b, s, h] \times [h, 4h] \to [b, s, 4h]$, and the amount of computation is $8bsh^2$
- The weight matrix $W_2$ of the second linear layer has shape $[4h, h]$, which is equivalent to mapping the dimension from $4h$ back to $h$. The input and output shapes of the matrix multiplication are $[b, s, 4h] \times [4h, h] \to [b, s, h]$, and the amount of computation is $8bsh^2$
Adding up the amounts of computation listed above, the amount of computation of each transformer layer is approximately $24bsh^2 + 4bs^2h$
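The per-layer total can be sanity-checked by summing the individual matmul costs. The helper below is a hypothetical illustration (the function name is my own, not from any library):

```python
# Sketch: add up the per-layer transformer FLOPs from its individual
# matrix multiplications and check the 24*b*s*h^2 + 4*b*s^2*h total.

def per_layer_flops(b, s, h):
    qkv      = 3 * 2 * b * s * h * h   # Q, K, V projections: 6bsh^2
    qk       = 2 * b * s * s * h       # Q @ K^T:             2bs^2h
    score_v  = 2 * b * s * s * h       # score @ V:           2bs^2h
    out_proj = 2 * b * s * h * h       # output projection:   2bsh^2
    mlp_up   = 2 * b * s * h * 4 * h   # h -> 4h:             8bsh^2
    mlp_down = 2 * b * s * 4 * h * h   # 4h -> h:             8bsh^2
    return qkv + qk + score_v + out_proj + mlp_up + mlp_down

b, s, h = 1, 2048, 4096
total = per_layer_flops(b, s, h)
assert total == 24 * b * s * h**2 + 4 * b * s**2 * h
print(f"{total / 1e12:.2f} TFLOPs per layer")
```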
1.1.3 The amount of computation for the logits

In addition, another major part of the computation is the calculation of the logits (after all, the word embedding matrix also has a large number of parameters), which maps the hidden vector to the vocabulary size. To put it bluntly, the word-vector dimension is usually equal to the hidden dimension $h$, the parameter count of the word embedding matrix is $Vh$ (where $V$ is the vocabulary size), and the weight matrix of the final output layer usually shares parameters with the word embedding matrix (to explain: as Teacher Du Qiyue said, this is an important point about the transformer; parameter sharing reduces the parameter count, and since the word embedding matrix is [vocab_size, hidden_size] while the output layer matrix is [hidden_size, vocab_size], they can be shared)

The input and output shapes of this matrix multiplication are $[b, s, h] \times [h, V] \to [b, s, V]$, and the amount of computation is $2bshV$

Therefore, for an $l$-layer transformer model, when the input data shape is $[b, s]$, the amount of computation of one training iteration is the combination of the above three parts, namely:
$$l \cdot (24bsh^2 + 4bs^2h) + 2bshV$$
1.2 Transformer’s space complexity——Self-Attention layer and MLP layer
Stating the conclusion first: the GPU memory occupied by the intermediate activations of each transformer layer is $34bsh + 5bs^2a$, where $a$ is the number of attention heads
Large models usually use mixed-precision training, so the intermediate activation values are generally of float16 or bfloat16 data type. When analyzing the memory usage of intermediate activations, it is assumed that they are saved in float16 or bfloat16 format, with each element occupying 2 bytes. The only exception is the mask matrix of the dropout operation, where each element occupies only 1 byte. In the following analysis, the unit is bytes, not the number of elements.
Each transformer layer contains a self-attention block and an MLP block, and each of the two is paired with a layer normalization.
1.2.1 Intermediate activation of Self-Attention block
The calculation formula of the self-attention block is as follows:
$$Q = xW_Q,\quad K = xW_K,\quad V = xW_V$$
$$x_{out} = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) \cdot V \cdot W_O + x$$

Finally, the GPU memory occupied by the intermediate activations of the self-attention block is $11bsh + 5bs^2a$ bytes
How is it calculated specifically?
- For $Q$, $K$, $V$, their common input $x$ needs to be saved, and this is an intermediate activation. The shape of the input is $[b, s, h]$, the number of elements is $bsh$, and the occupied GPU memory is $2bsh$ bytes
- For the matrix multiplication $QK^T$, the intermediate activations $Q$ and $K$ need to be saved; both tensors have shape $[b, s, h]$, so the total occupied GPU memory is $4bsh$ bytes
- For the softmax function, its input $QK^T$ needs to be saved. The shape of $QK^T$ is $[b, a, s, s]$, the number of elements is $bs^2a$, and the occupied GPU memory is $2bs^2a$ bytes, where $a$ represents the number of attention heads
- After the softmax function has been computed, the dropout operation is performed, and a mask matrix needs to be saved. The shape of the mask matrix is the same as that of $QK^T$, and since each mask element occupies 1 byte, the occupied GPU memory is $bs^2a$ bytes
- The attention applied to $V$, i.e. $\text{score} \cdot V$, requires saving $\text{score}$, of size $2bs^2a$ bytes, and $V$, of size $2bsh$ bytes; the two together occupy $2bs^2a + 2bsh$ bytes
- The output mapping and a dropout operation follow. The output mapping needs to save its input, of size $2bsh$ bytes; dropout needs to save its mask matrix, of size $bsh$ bytes; the two together occupy $3bsh$ bytes

Therefore, adding up the above intermediate activations, the intermediate activations of the self-attention block occupy $11bsh + 5bs^2a$ bytes of GPU memory
1.2.2 Intermediate activation of MLP block
The calculation formula of the MLP block is as follows:
$$x = f_{gelu}(x_{out} W_1) W_2 + x_{out}$$
Finally, for the MLP block, the intermediate activations that need to be saved occupy $19bsh$ bytes
How is it calculated specifically?
- The first linear layer needs to save its input, of shape $[b, s, h]$, occupying $2bsh$ bytes of GPU memory
- The activation function needs to save its input, of shape $[b, s, 4h]$, occupying $8bsh$ bytes of GPU memory
- The second linear layer needs to save its input, of shape $[b, s, 4h]$, occupying $8bsh$ bytes of GPU memory
- Finally, there is a dropout operation, which needs to save its mask matrix, occupying $bsh$ bytes of GPU memory
1.2.3 The intermediate activations of the two layer norms

In addition, the self-attention block and the MLP block each have a corresponding layer normalization. Each layer norm needs to save its input, of size $2bsh$ bytes, so the 2 layer norms need $4bsh$ bytes of intermediate activations

To sum up, the intermediate activations that each transformer layer needs to save occupy $34bsh + 5bs^2a$ bytes of GPU memory
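As a sketch, the whole activation-memory breakdown can be added up in code. The function names below are hypothetical, and the byte counts follow the mixed-precision assumptions above (2 bytes per element, 1 byte per dropout-mask element):

```python
# Sketch: per-layer activation memory in bytes, following the
# 34*b*s*h + 5*b*s^2*a breakdown derived above.

def attention_activations(b, s, h, a):
    return (2*b*s*h                   # input x shared by Q, K, V
            + 4*b*s*h                 # Q and K for Q @ K^T
            + 2*b*s*s*a               # softmax input QK^T
            + 1*b*s*s*a               # dropout mask after softmax (1 byte/elem)
            + 2*b*s*s*a + 2*b*s*h     # score and V for score @ V
            + 2*b*s*h + 1*b*s*h)      # output-projection input + its mask

def mlp_activations(b, s, h):
    return 2*b*s*h + 8*b*s*h + 8*b*s*h + 1*b*s*h

def layer_activations(b, s, h, a):
    # the two layer norms each save their 2*b*s*h-byte input
    return attention_activations(b, s, h, a) + mlp_activations(b, s, h) + 2 * 2*b*s*h

b, s, h, a = 1, 4096, 4096, 32
total = layer_activations(b, s, h, a)
assert total == 34*b*s*h + 5*b*s*s*a
print(f"{total / 2**30:.2f} GiB per layer")
```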
For an $l$-layer transformer model, there are additionally the embedding layer and the final output layer. The embedding layer does not require saving intermediate activations. In general, when the hidden dimension $h$ is relatively large and the number of layers $l$ is deep, the intermediate activations of these parts are very small and can be ignored
Therefore, for an $l$-layer transformer model, the GPU memory occupied by intermediate activations can be approximated as $l \cdot (34bsh + 5bs^2a)$. For more analysis, see this article: "Analysis of parameters, calculations, intermediate activations, and KV cache of the transformer model"
Through the content of the above two sections, we can see that both the amount of computation and the memory of the transformer model grow quadratically with the sequence length $s$, which limits the maximum sequence length of large language models

Secondly, GPT4 has expanded the maximum sequence length to 32K, and Claude has expanded the maximum sequence length to 100K; these efforts must have adopted some optimization methods to reduce the complexity of the native transformer. How exactly can it be optimized?

We know that each transformer layer is divided into two parts, the self-attention block and the MLP block, but the $4bs^2h$ term in the computation above and the $5bs^2a$ term in the intermediate activations are both produced by the self-attention block and have nothing to do with the MLP block
1.3 Two problems with Standard Attention: high memory usage and high number of HBM reads and writes
- To review, the calculation process of the attention mechanism in the transformer is (again, if you have forgotten the details of the transformer, it is recommended to first read: Transformer popular notes; if you have forgotten what softmax is, review this article: How to understand Word2Vec in a popular way):
$$O = \text{softmax}(QK^T)V$$
where $Q, K, V \in \mathbb{R}^{N \times d}$, $N$ is the sequence length, $d$ is the dimension of each attention head, and the output can be written as $O \in \mathbb{R}^{N \times d}$
- The above formula can be broken down into the following three steps:
$$S = QK^T \in \mathbb{R}^{N \times N}$$
$$P = \text{softmax}(S) \in \mathbb{R}^{N \times N}$$
$$O = PV \in \mathbb{R}^{N \times d}$$
In the standard attention implementation, $S$ and $P$ must be written back to HBM (this HBM will be explained shortly below), occupying $O(N^2)$ memory; usually $N \gg d$
- For example, for GPT2, $N = 1024$ and $d = 64$; for GPT3, $N = 2048$ and $d = 128$. In short, the memory required by the attention matrices $S, P \in \mathbb{R}^{N \times N}$ is much larger than the memory required by $Q$, $K$, $V$, and $O$
The figure below shows the implementation process of standard attention:
$$S = QK^T, \quad P = \text{softmax}(S), \quad O = PV$$
There are a total of eight HBM matrix read/write operations:
- The first line reads $Q$ and $K$ (two reads) and writes $S$ (one write), for a total of three read/write operations
- The second line reads $S$ once and writes $P$ once, for a total of two read/write operations
- The third line reads $P$ and $V$ (two reads) and writes $O$ (one write), for a total of three read/write operations
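The three-step standard implementation can be sketched in NumPy as follows; note that `S` and `P` are full $N \times N$ matrices, which is exactly the materialization (and HBM traffic) that later sections set out to avoid:

```python
# Sketch of standard (materializing) attention: S = Q K^T,
# P = softmax(S), O = P V.  S and P are full N x N matrices.

import numpy as np

def standard_attention(Q, K, V):
    S = Q @ K.T                               # N x N scores (written to HBM)
    S_max = S.max(axis=-1, keepdims=True)     # "safe softmax" shift
    P = np.exp(S - S_max)
    P /= P.sum(axis=-1, keepdims=True)        # N x N probabilities
    return P @ V                              # N x d output

rng = np.random.default_rng(0)
N, d = 1024, 64                               # GPT2-like sizes
Q, K, V = (rng.standard_normal((N, d)) for _ in range(3))
O = standard_attention(Q, K, V)
assert O.shape == (N, d) and np.isfinite(O).all()
```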
Add some background knowledge
- There have been many approximate-attention approaches that attempt to reduce the computational and memory requirements of attention. For example, sparse-approximation and low-rank-approximation methods reduce the computational complexity to linear or sub-linear in the sequence length.
- However, these approximate attention methods have not been widely used. Because these methods focus too much on reducing FLOPS (number of floating point calculations) and ignore the memory access overhead of IO reading and writing, they do not effectively reduce running time (wall-clock time).
- In short, on modern GPUs, the computing speed has far exceeded the memory access speed, and the bottleneck of most computing operations in the transformer is GPU memory access. For memory-bound operations, IO awareness is very important, because memory reads and writes take up most of the running time
The memory of the GPU is composed of multiple memories of different sizes and different read/write speeds; the smaller the memory, the faster the read/write speed. For the A100-40GB, the memory hierarchy is as follows
- SRAM is distributed across 108 streaming multiprocessors, each 192KB in size, totaling $108 \times 192\text{KB} = 20\text{MB}$, with a read/write speed of 19TB/s
It is, so to speak, fast to compute with but small in capacity
- High Bandwidth Memory (HBM), which is what we usually call GPU memory, has a size of 40GB, but its read/write speed is only 1.5TB/s, less than 1/10 of SRAM's
It is, so to speak, large in capacity but slow to access
In short, the computational complexity and space complexity of the self-attention block, the core component of the transformer, are both quadratic in the sequence length. Moreover, within the self-attention block, apart from the two large matrix multiplications, which are compute-bound ($QK^T$ and $PV$), everything else is a memory-bound pointwise operation (for example, the mask operation on $S$, the softmax operation on $S$, and the dropout operation on $P$); the performance of these pointwise operations is limited by memory bandwidth, and they slow down the running time

That is, there are two problems in the standard attention implementation:
- It occupies a lot of GPU memory: the process instantiates the complete attention matrix $P \in \mathbb{R}^{N \times N}$, resulting in $O(N^2)$ memory requirements
- It performs many HBM reads and writes, which slows down the running time (wall-clock time)

Memory-efficient Attention in Section 2.1 below and Flash Attention in Section 2.2 solve these two problems respectively
Part 2 The forward pass: from Memory-efficient Attention to Flash Attention
2.1 Memory-efficient Attention: Reduce the memory complexity from quadratic to linear, but the number of HBM accesses is still quadratic
In the attention calculation process, the main challenge in saving GPU memory is that softmax is coupled with the columns of $K$. The approach is to compute the normalization factor of softmax separately to achieve decoupling
- In order to simplify the analysis, ignore the "subtract the maximum value" step when computing softmax
Denote the $i$-th row of $Q$ as $q_i$ and the $j$-th row of $K$ as $k_j$, where $q_i, k_j \in \mathbb{R}^d$ and there are $N$ of each. Define the normalization factor of softmax as:
$$L_i = \sum_{j=1}^{N} e^{q_i^T k_j}$$
- Denote $v_j \in \mathbb{R}^d$ as the $j$-th row vector of $V$; then the $i$-th row vector $o_i$ of the output $O$ is:
$$o_i = \sum_{j=1}^{N} \frac{e^{q_i^T k_j}}{L_i} v_j$$
- After the normalization factor $L_i$ has been computed, $o_i$ can be obtained by repeated accumulation

In this way, the memory-efficient attention mechanism changes the calculation order. Compared with Standard Attention, it reduces the memory complexity from $O(N^2)$ to $O(N)$

This method has been used in "Online normalizer calculation for softmax" and "Self-attention Does Not Need $O(n^2)$ Memory" and is called "lazy softmax". It avoids instantiating the complete attention matrix, thereby achieving the goal of saving GPU memory. However, the number of HBM accesses is still $O(N^2)$, so the running time is not reduced
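A rough NumPy sketch of this lazy-softmax idea (skipping the max-subtraction step, as the derivation above does) computes each output row with only $O(N)$ extra memory, never materializing the $N \times N$ matrix; the function name is my own:

```python
# Sketch of "lazy softmax" / memory-efficient attention: compute the
# normalization factor L_i first, then accumulate the output row o_i,
# without ever forming the full N x N attention matrix.

import numpy as np

def memory_efficient_attention(Q, K, V):
    N, d = Q.shape
    O = np.zeros((N, V.shape[1]))
    for i in range(N):
        # normalization factor L_i = sum_j exp(q_i . k_j)
        L_i = np.exp(K @ Q[i]).sum()
        # accumulate o_i = sum_j exp(q_i . k_j) / L_i * v_j
        for j in range(K.shape[0]):
            O[i] += np.exp(Q[i] @ K[j]) / L_i * V[j]
    return O

rng = np.random.default_rng(0)
N, d = 64, 16
# small-magnitude scores, since this sketch skips max subtraction
Q, K, V = (0.1 * rng.standard_normal((N, d)) for _ in range(3))
ref = np.exp(Q @ K.T)
ref /= ref.sum(axis=-1, keepdims=True)
assert np.allclose(memory_efficient_attention(Q, K, V), ref @ V)
```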
2.2 Flash Attention: Reduce the number of HBM reads and writes through kernel fusion, avoiding frequent reading and writing of data from HBM
As mentioned above
- In the standard attention implementation, the performance of attention is mainly limited by memory bandwidth and is memory-limited. Frequently reading and writing matrices from HBM is the main bottleneck affecting performance
- Although approximate attention methods such as sparse approximation and low-rank approximation reduce the computational FLOPs, for memory-bound operations the bottleneck of running time is the time taken to read and write data from HBM, so reducing the amount of computation cannot effectively reduce the running time (wall-clock time)
- For memory-bound standard attention, Flash Attention is IO-aware, and its goal is to avoid frequently reading and writing data from HBM

Therefore, it is very important to reduce the number of reads and writes to HBM and make effective use of the higher-speed SRAM for computation. For operations whose performance is limited by memory bandwidth, a common acceleration method is kernel fusion, which typically consists of the following three steps:
- Each kernel loads input data from the low-speed HBM into the high-speed SRAM
- In SRAM, calculations are performed
- After the calculation is completed, the calculation results are written from SRAM to HBM.
In this way, instead of repeatedly executing "read input data from HBM, perform calculations in SRAM, and finally write the calculation results to HBM", multiple operations are merged into one operation, reducing the number of HBM reads and writes (it should be noted that model training usually limits the effect of operator fusion, because in order to compute gradients in the backward pass, certain intermediate results usually still have to be written to HBM)
Some students may not understand the above explanation. In fact, the principle is very simple, that is, the following two sentences
- If data is written from SRAM back to HBM just to be (re)loaded to compute softmax
- Then it is better to keep it in SRAM, perform all the intermediate steps there, and only write the final result back to HBM

The former is shown on the left side of the figure below, and the latter on the right side (image source)
2.2.1 Comprehensive explanation of blockwise attention computation (tiling)——kernel fusion must fit within the memory size of SRAM, but the SRAM memory is too small
Although through kernel fusion, multiple operations are merged into one operation and the use of high-speed SRAM for calculation can reduce the number of reads and writes to HBM, thereby effectively reducing the running time of memory-limited operations. But there is a problem
- The memory size of SRAM is limited, and it is impossible to compute the complete attention in one go. Therefore, the computation must be performed in blocks, so that the memory required by each block does not exceed the size of SRAM.
This amounts to: memory is limited --> reduce the number of HBM reads and writes --> kernel fusion --> fit within the memory size of SRAM --> compute in blocks. Hence the block size block_size cannot be too large, otherwise it will cause OOM
- What is the difficulty of block computation?
The calculation process of the attention mechanism is "matrix multiplication --> scale --> mask --> softmax --> dropout --> matrix multiplication". Block computation of the matrix multiplications and the pointwise operations (scale, mask, dropout) is easy to implement, but the difficulty lies in the block computation of softmax: since computing the normalization factor (denominator) of softmax requires the complete input data, it is harder to compute in blocks

How should we understand the sentence "since computing the normalization factor (denominator) of softmax requires the complete input data, it is harder to compute in blocks"?
Let’s first review the calculation formula of softmax.
- Considering the vector $x = [x_1, \dots, x_d]$, the calculation process of native softmax is as follows:
$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{d} e^{x_j}}$$
- In actual hardware, the range that floating-point numbers can represent is limited
For float32 and bfloat16, when $x \geq 89$, $e^x$ becomes very large or even inf, causing the problem of data overflow. Therefore, in order to avoid numerical overflow and ensure numerical stability, the calculation usually "subtracts the maximum value": defining $m$ as the maximum element of $x$, all deep learning frameworks now use the "safe softmax" calculation
$$\text{softmax}(x_i) = \frac{e^{x_i - m}}{\sum_{j=1}^{d} e^{x_j - m}}$$
- When training a language model, the cross-entropy loss function is usually used. The cross-entropy loss is equivalent to first executing a log_softmax function and then computing the negative log-likelihood, and when computing log_softmax, the "subtract the maximum value" trick is likewise applied, which not only avoids numerical overflow and improves numerical stability, but can also speed up the calculation

In summary, to compute how much attention a particular $i$-th token in the input sequence pays to the other tokens in the sequence, all of these scores need to be available in SRAM at once. But the sequence length $N$ can be 1,000 or even 100,000 tokens, so $N^2$ explodes, while the capacity of SRAM is limited.
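A tiny pure-Python sketch of why "subtracting the maximum" matters: naive softmax overflows for large scores, while safe softmax does not (the function names are my own):

```python
# Sketch: naive softmax overflows float64 for large inputs,
# safe softmax (subtract the max first) does not.

import math

def naive_softmax(xs):
    exps = [math.exp(x) for x in xs]          # overflows once x is large
    s = sum(exps)
    return [e / s for e in exps]

def safe_softmax(xs):
    m = max(xs)                               # subtract the maximum first
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

xs = [800.0, 801.0, 802.0]
try:
    naive_softmax(xs)
    overflowed = False
except OverflowError:
    overflowed = True
assert overflowed                             # e^800 overflows float64
probs = safe_softmax(xs)
assert abs(sum(probs) - 1.0) < 1e-12          # still a valid distribution
print([round(p, 4) for p in probs])
```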
In short, the main idea of tiling is to compute attention in chunks. The difficulty of block computation lies in the block computation of softmax, since softmax is coupled with the columns of $K$. By introducing two additional statistics $m(x)$ and $l(x)$ for decoupling (the former is the running maximum score, the latter the running sum of exp scores), block computation is achieved
2.2.1.1 Comprehensively understanding blockwise attention computation (tiling) through 23 formulas
Let’s start from the beginning and sort it out comprehensively (the explanation of the following 23 formulas draws on this source)
- Considering the vector $x = [x_1, \dots, x_d]$, the calculation process of native softmax is as follows:
$$\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{d} e^{x_j}} \tag{4}$$
The numerator takes the exponent of the corresponding element $x_i$ of the vector, and the denominator is the sum of the exponents of all elements, which ensures that the output of the softmax function is a probability distribution, i.e. all elements sum to 1
- $m(x)$ is defined as the maximum value among $x_1, \dots, x_d$:
$$m(x) = \max(x_1, \dots, x_d) \tag{5}$$
- $f(x)$ is a new vector in which each term corresponds to the numerator of the standard softmax of Formula 4, except that each exponent additionally subtracts the maximum $m(x)$:
$$f(x) = [e^{x_1 - m(x)}, \dots, e^{x_d - m(x)}] \tag{6}$$
- $l(x)$ is the summation term in the denominator of softmax. For convenience of the later description, the summation term of Formula 7 will be called the "EXP summation term":
$$l(x) = \sum_{i=1}^{d} f(x)_i \tag{7}$$
- The safe softmax is then:
$$\text{softmax}(x) = \frac{f(x)}{l(x)} \tag{8}$$
Now consider a vector $x \in \mathbb{R}^{2d}$ of size $2d$ and divide it into two blocks:
$$x = [x^{(1)}, x^{(2)}], \quad x^{(1)} = [x_1, \dots, x_d], \quad x^{(2)} = [x_{d+1}, \dots, x_{2d}]$$
In other words, the sub-vector $x^{(1)}$ is the first half of the original vector $x$, and the sub-vector $x^{(2)}$ is the second half
Assume that in the block computation, $x^{(1)}$ is processed first and then $x^{(2)}$
Then, applying Formula 5 to Formula 8 to the sub-vector $x^{(1)}$, its "local" softmax is computed as shown in Formulas 9-12:
$$m(x^{(1)}) = \max(x_1, \dots, x_d) \tag{9}$$
$$f(x^{(1)}) = [e^{x_1 - m(x^{(1)})}, \dots, e^{x_d - m(x^{(1)})}] \tag{10}$$
$$l(x^{(1)}) = \sum_i f(x^{(1)})_i \tag{11}$$
$$\text{softmax}(x^{(1)}) = \frac{f(x^{(1)})}{l(x^{(1)})} \tag{12}$$
Obviously, the $\text{softmax}(x^{(1)})$ obtained so far cannot be regarded as the final result for the sub-vector $x^{(1)}$. The reason is very simple:
First, the maximum subtracted in the exponent of Formula 10 should be the maximum of the entire vector $x$, not merely the maximum of the sub-vector $x^{(1)}$; second, the EXP summation term in the denominator of Formula 12 should be the summation over the entire vector $x$, not merely the summation over all elements of the sub-vector $x^{(1)}$. Because the result obtained by the above calculation is not final, it is called "local"
Next, we introduce how to save some additional variable values and update the "local" values:
First, after processing the sub-vector $x^{(1)}$, save $m(x^{(1)})$ and $l(x^{(1)})$; saving just these two scalars is much cheaper than saving the entire sub-vector $x^{(1)}$
Secondly, two global scalars also need to be saved. $m_{max}$ represents the current maximum value; because only $x^{(1)}$ has been processed so far, temporarily:
$$m_{max} = m(x^{(1)}) \tag{13}$$
$l_{all}$ represents the global EXP summation term; because only $x^{(1)}$ has been processed so far, temporarily:
$$l_{all} = l(x^{(1)}) \tag{14}$$
Then, processing $x^{(2)}$ in a similar way to $x^{(1)}$, the following results are obtained:
$$f(x^{(2)}) = [e^{x_{d+1} - m(x^{(2)})}, \dots, e^{x_{2d} - m(x^{(2)})}] \tag{15}$$
$$\text{softmax}(x^{(2)}) = \frac{f(x^{(2)})}{l(x^{(2)})} \tag{16}$$
In the same way, the softmax obtained by Formula 16 is also local rather than global.
- But after processing $x^{(2)}$, the information $m(x^{(2)})$ and $l(x^{(2)})$ can be used to update the two previously saved global scalars $m_{max}$ and $l_{all}$, as shown in Formulas 17 and 18 below:
$$m_{max}^{new} = \max\big(m_{max},\; m(x^{(2)})\big) \tag{17}$$
$$l_{all}^{new} = e^{m_{max} - m_{max}^{new}} l_{all} + e^{m(x^{(2)}) - m_{max}^{new}} l(x^{(2)}) \tag{18}$$
- The meaning of Formula 17 is very simple: the updated global maximum is the larger of the previous global maximum $m_{max}$ and $m(x^{(2)})$
- Formula 18 is the method for updating the global EXP summation term. Wait a minute, how did this come about? Shouldn't it be $l_{all}^{new} = l_{all} + l(x^{(2)})$? Taking $l(x^{(2)})$ as an example, we say $l(x^{(2)})$ is "local" because so far it has only used the information of $m(x^{(2)})$; to update it to "global", $m_{max}^{new}$ is needed. Expanding $f(x^{(2)})$ from Formula 15 slightly inside the sum, we get:
$$l(x^{(2)}) = \sum_i f(x^{(2)})_i = \sum_{i=1}^{d} e^{x_{d+i} - m(x^{(2)})} \tag{19}$$
- It can be seen that the reason $l(x^{(2)})$ is "local" rather than "global" is that the max value it subtracts is "local". So we only need to replace this max value with the global one
For this purpose, $l(x^{(2)})$ can be transformed to become global by multiplying it by one factor, i.e.:
$$e^{m(x^{(2)}) - m_{max}^{new}} \cdot l(x^{(2)}) = e^{m(x^{(2)}) - m_{max}^{new}} \sum_{i=1}^{d} e^{x_{d+i} - m(x^{(2)})} = \sum_{i=1}^{d} e^{x_{d+i} - m_{max}^{new}} \tag{20}$$
- At this point, $l(x^{(2)})$ has been updated to "global". Treating $l_{all}$ in exactly the same way, i.e. multiplying it by $e^{m_{max} - m_{max}^{new}}$, and then summing the two, gives the current global EXP summation term. This is exactly what Formula 18 does: it uses this globalizing update on $l_{all}$ and $l(x^{(2)})$ respectively, then sums them
- Returning to Formula 16: since both the numerator and the denominator of the local $\text{softmax}(x^{(2)})$ are local, both need to be updated to the global level, and then the softmax value itself can be updated directly
Let's look at the numerator part first. According to Formula 16, the numerator is $f(x^{(2)})$, whose elements are $e^{x_i - m(x^{(2)})}$, where $m(x^{(2)})$ represents the current local maximum. To make $f(x^{(2)})$ "global", just multiply it by one term, i.e.:
$$e^{m(x^{(2)}) - m_{max}^{new}} \cdot f(x^{(2)}) = [e^{x_{d+1} - m_{max}^{new}}, \dots, e^{x_{2d} - m_{max}^{new}}] \tag{21}$$
- Comparing before and after the transformation once again confirms the conclusion obtained from Formula 20, that is: to turn a local value into a global value, just multiply it by the term $e^{m_{local} - m_{max}^{new}}$, where $m_{local}$ is the local maximum it originally subtracted and $m_{max}^{new}$ is the current global maximum
- Now let's look at the denominator part. To turn the denominator from local to global, we actually only need to replace the denominator $l(x^{(2)})$ with $l_{all}^{new}$. This can be done by the following formula:
$$\text{softmax}^{new}(x^{(2)}) = \frac{e^{m(x^{(2)}) - m_{max}^{new}} f(x^{(2)})}{l_{all}^{new}} \tag{22}$$
- where $l_{all}^{new}$ is calculated from Formula 18
Okay, here come the questions

Question 1: Many friends online have expressed doubts about this: why is the denominator in Formula 22 $l_{all}^{new}$ instead of $l(x^{(2)})$?
Answer: The reason is very simple. Consider why we use softmax: it assigns each element of a vector a probability value between 0 and 1, such that the sum of these probabilities is 1
When we say "global", we want to assign a probability to each element of the entire data set, not just to a subset of the data set
So when a data stream is split into two parts $x^{(1)}$ and $x^{(2)}$, in order to compute the softmax of the entire data stream ($x^{(1)}$ and $x^{(2)}$ combined), you cannot consider $x^{(2)}$ alone; you must consider the combined global effect of $x^{(1)}$ and $x^{(2)}$, which is exactly what $l_{all}^{new}$ captures

Question 2: The next question may come right away: some students may soon ask what the difference is between Formula 19 and Formula 20, given that both are EXP summation terms over $x^{(2)}$
Answer: The main difference is the reference maximum they use. In Formula 19, the maximum is $m(x^{(2)})$, i.e. a local maximum: for this data block, we compare each element with the maximum within the block itself. In Formula 20, the maximum is $m_{max}^{new}$, the global maximum of all elements observed so far: we compare each element of $x^{(2)}$ with that global maximum. This transformation is for numerical stability, ensuring that we will not encounter numerical overflow problems when computing the exponent of e; in this sense, Formula 20 is just the global version of Formula 19

Question 3: Finally, how is $\text{softmax}(x^{(1)})$, computed back in Formula 12, updated to the global level? Combining Formula 21 and Formula 22, the update can be achieved as follows:
$$\text{softmax}^{new}(x^{(1)}) = \frac{\text{softmax}(x^{(1)}) \cdot l(x^{(1)}) \cdot e^{m(x^{(1)}) - m_{max}^{new}}}{l_{all}^{new}} \tag{23}$$
Look carefully at Formula 23. When we update the value of $\text{softmax}(x^{(1)})$, we use only the additionally saved quantities mentioned earlier: the local softmax value $\text{softmax}(x^{(1)})$ from Formula 12, the local EXP summation term $l(x^{(1)})$ from Formula 11, the local maximum $m(x^{(1)})$ from Formula 9, the global maximum $m_{max}^{new}$ from Formula 17, and the global EXP summation term $l_{all}^{new}$ from Formula 18. Multiplying $\text{softmax}(x^{(1)})$ by $l(x^{(1)})$ recovers the local numerator $f(x^{(1)})$; multiplying by $e^{m(x^{(1)}) - m_{max}^{new}}$ corrects the local maximum in the exponent to the global one; dividing by $l_{all}^{new}$ renormalizes with the global denominator. In the same way, $\text{softmax}(x^{(2)})$ can be updated. The entire update process never needs to use the vector values of $x^{(1)}$ again. This is the essence of how FlashAttention dynamically updates softmax values
The above is actually an incremental calculation process.
- We first calculate the local softmax value of a block and then store it
- When the next block is processed, the old softmax value can be updated based on the new global maximum value and global EXP summation term at this time, and then the next block is processed, and then updated
- After all blocks are processed, the softmax values of all blocks at this time are "global"
2.2.1.2 A brief summary of blockwise attention computation (tiling)
Your brain may be running a little hot by now. To ease the burn, let's summarize the above process through a simple example.
For two vectors x^(1) and x^(2), the softmax of the spliced vector x = [x^(1), x^(2)] can be computed in a decoupled, block-wise fashion:
By maintaining two additional statistics , softmax can be calculated in blocks. It should be noted that GPU multi-threading can be used to calculate the softmax of multiple blocks in parallel at the same time. In order to make full use of hardware performance, the calculation of multiple blocks is not serial but parallel
I seem to see a vague look of anxiety on your face. It's okay, don't worry — July understands. Bare formulas are obscure after all, so let's use an example to vividly illustrate how to calculate softmax in blocks.
Calculate the softmax of the vector [1, 2, 3, 4], dividing it into two blocks [1, 2] and [3, 4]:
Calculate block 1: m(x^(1)) = max(1, 2) = 2, f(x^(1)) = [e^(1−2), e^(2−2)] = [e^(−1), 1], ℓ(x^(1)) = e^(−1) + 1
Calculate block 2: m(x^(2)) = max(3, 4) = 4, f(x^(2)) = [e^(3−4), e^(4−4)] = [e^(−1), 1], ℓ(x^(2)) = e^(−1) + 1
Combine to get the complete softmax result: the global maximum is m(x) = max(2, 4) = 4, so f(x) = [e^(2−4) · f(x^(1)), e^(4−4) · f(x^(2))] = [e^(−3), e^(−2), e^(−1), 1] and ℓ(x) = e^(2−4) · ℓ(x^(1)) + e^(4−4) · ℓ(x^(2)) = e^(−3) + e^(−2) + e^(−1) + 1, giving softmax(x) = f(x) / ℓ(x) — exactly the softmax of [1, 2, 3, 4] computed directly
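The two-block merge above can be sketched in a few lines of NumPy (the helper names `block_stats` and `merge_softmax` are mine, not from the FlashAttention code):

```python
import numpy as np

def block_stats(x):
    """Local statistics of one block: max m, shifted exponentials f, and their sum l."""
    m = np.max(x)
    f = np.exp(x - m)
    return m, f, f.sum()

def merge_softmax(x1, x2):
    """Softmax of the concatenation [x1, x2], computed only from the
    two blocks' local statistics (the two-block merge described above)."""
    m1, f1, l1 = block_stats(x1)
    m2, f2, l2 = block_stats(x2)
    m = max(m1, m2)                                 # global maximum
    f = np.concatenate([np.exp(m1 - m) * f1,        # correct block 1
                        np.exp(m2 - m) * f2])       # correct block 2
    l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2   # global EXP sum
    return f / l

x = np.array([1.0, 2.0, 3.0, 4.0])
blockwise = merge_softmax(x[:2], x[2:])
direct = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(blockwise, direct))  # → True
```

Note that neither block ever sees the other's raw elements — only the pair of statistics (m, ℓ) crosses the block boundary, which is what makes the tiled computation possible.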
2.2.1.3 Forward calculation algorithm of Flash Attention algorithm
Simplifying the analysis by ignoring mask and dropout for now, the forward calculation process of the Flash Attention algorithm is as follows
As can be seen from the above figure, the algorithm performs an outer loop over the K, V dimension and an inner loop over the Q dimension (note that in the Triton code implementation it is the opposite: the outer loop runs over the Q dimension and the inner loop over the K, V dimension)
To be thorough, I will explain the above 16 lines of pseudocode line by line. To make them easier to understand, I will borrow a flow chart drawn by marsggbo on Zhihu — you can refer to it as you follow the corresponding lines.
First, the basic setup:
Among them, N is the sequence length, d is the dimension of each attention head, and M is the size of the SRAM
Set the block sizes Bc = ⌈M / 4d⌉, Br = min(⌈M / 4d⌉, d)
This computes the row/column block sizes. Why divide M by 4d (taking the ceiling)? Because the query, key and value vectors are each d-dimensional, and we also need to combine them into the d-dimensional output vector, so this sizing basically lets us fill the SRAM to capacity with the q, k, v and o vectors. Take GPT-2 and the A100 as an example:
The on-chip SRAM size of the A100 is 192KB per streaming multiprocessor
In GPT-2, N is 1024 and d is 64, so the intermediate attention matrix S = QKᵀ has dimension N × N = 1024 × 1024 — far too large to keep on chip, which is why only small blocks of it are ever materialized in SRAM
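The sizing arithmetic can be sketched as follows (the 192KB SRAM budget and d = 64 are the example figures above; `flash_block_sizes` is a made-up helper name):

```python
import math

def flash_block_sizes(sram_bytes, d, bytes_per_elem=4):
    """Block sizes as in the FlashAttention paper: Bc = ceil(M / 4d),
    Br = min(ceil(M / 4d), d), where M is the SRAM budget in elements."""
    M = sram_bytes // bytes_per_elem      # SRAM budget in fp32 elements
    Bc = math.ceil(M / (4 * d))           # q, k, v, o rows each take d elements
    Br = min(Bc, d)
    return Br, Bc

# A100-style 192KB of SRAM per SM, GPT-2-style head dimension d = 64
print(flash_block_sizes(192 * 1024, 64))  # → (64, 192)
```

So under these assumptions each K/V block holds 192 rows and each Q/O block holds 64 rows, keeping all four working tiles on chip at once.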
Initialize the output matrix O (of shape N × d) with all 0s; it will act as an accumulator
Similarly, initialize the vector ℓ (of length N) with all 0s; its purpose is to save the cumulative softmax denominator — the running sum of exp scores
The vector m (also of length N) saves the row-wise maximum score seen so far and is initialized to −inf, because we are going to apply a max operator to it: whatever the max of the first block turns out to be, it will certainly be greater than −inf
Divide Q, K and V into blocks. Specifically, Q is divided into Tr blocks along the row direction, each of size Br × d; K and V are each divided into Tc blocks along the row direction, each of size Bc × d
Split O, ℓ and m into blocks
where O has the same size as Q and is likewise divided into Tr blocks along the row direction, each of size Br × d; as for the vectors ℓ and m, each is divided into Tr sub-vectors of size Br
Combining the above two steps 3 and 4, the relationship between the blocks is as follows
for 1 ≤ j ≤ Tc do
Start the loop across columns (i.e. the outer loop, controlled by j, moving from one column block to the next), that is, across the key/value vectors, traversing K1…KTc and V1…VTc for a total of Tc iterations
Load Kj, Vj from slow HBM to on-chip fast SRAM.
Load the Kj and Vj blocks from HBM to SRAM (their size is Bc × d each). At this point in time we still have roughly 50% of the SRAM unoccupied (reserved for Qi and Oi)
for 1 ≤ i ≤ Tr do
Start the inner loop across rows (from one row block to the next), that is, across the query vectors, for a total of Tr iterations
Load Qi, Oi, ℓi, mi from HBM to on-chip SRAM.
Load the Qi (Br × d) and Oi (Br × d) blocks, together with ℓi (Br) and mi (Br), into SRAM. Here you need to ensure that they can all be loaded into SRAM (including all intermediate variables)
On chip, compute Sij = Qi Kjᵀ, of size Br × Bc
This step multiplies Qi by the transpose of Kj to obtain the block attention score Sij. The attention score obtained in the standard Transformer calculation is a full N × N matrix, as shown in the figure below (in the figure, N = 25 and the block size is 5)
The standard Transformer needs to calculate the attention scores of the entire matrix (gray), while each block iteration only calculates the scores shown in the blue and orange areas of the figure. For another example, assume the outer loop index is j (j = 3), the inner loop index is i (i = 2), N is 25, and the block size is 5. The following is the block just calculated (assuming a 1-based index):
On chip, compute m̃ij = rowmax(Sij), P̃ij = exp(Sij − m̃ij), ℓ̃ij = rowsum(P̃ij)
Use the scores Sij calculated in the previous step to compute m̃ij, P̃ij and ℓ̃ij:
First, compute the maximum value in each row of Sij: m̃ij = rowmax(Sij). Based on m̃ij, compute the exponential term (normalize by taking the row maximum, subtracting it from the row scores, then applying EXP): P̃ij = exp(Sij − m̃ij)
Then, based on P̃ij, compute the EXP summation term (the row-wise sum of the matrix P̃ij): ℓ̃ij = rowsum(P̃ij)
On chip, compute mi^new = max(mi, m̃ij) and ℓi^new = e^(mi − mi^new) · ℓi + e^(m̃ij − mi^new) · ℓ̃ij
This step computes mi^new and ℓi^new; to see how, we can again reuse the chart above:
mi contains the row-wise maximum of all previous blocks (j = 1 and j = 2, indicated in green), while m̃ij contains the row-wise maximum of the current block (indicated in yellow). To get mi^new we just need to take the maximum of mi and m̃ij; ℓi^new is obtained in a similar way
As above, ℓi and ℓ̃ij are combined via Formulas 17 and 18 to update ℓi^new, and mi and m̃ij play the corresponding roles for the maximum.
To better understand the formula in this line, you must first understand that the purpose of computing multiple rows together is batch computation: in the figure, each small block contains multiple rows (3 rows in the figure), but there is no interaction between the data of different rows — it is merely a batched calculation strategy. The real significance of blocking lies in the column direction, because the softmax summation runs along the column direction
So for convenience of understanding, you can imagine that Br equals 1, i.e. that only one block of size 1 × Bc in the figure above is computed at a time
Based on this simplification, let's walk through the entire softmax update process. We use S^(j) to represent the attention scores of the j-th block of a row, and ℓ^(j) to represent the EXP summation term of that block.
Since batch computation is no longer being considered, each processing step handles a single vector of attention scores. As shown in the figure above, we first take S^(1) and use Formulas 5 to 8 to compute its local ℓ^(1); at this time, only the first two positions of the result have values, corresponding to the local value. We then do the same with each row below it (the first two columns of the green part)
Then process S^(2): first use Formulas 5 to 8 to compute its local ℓ^(2), and then use Formula 23 to update the running value (note that, from line 11 of the pseudocode above, this update is equivalent to computing ℓi^new):
(recorded as Formula 24)
When S^(3) is processed, continue to apply Formula 24 to update:
(recorded as Formula 25)
Let's go one step further and try to update the output O directly, not just the value ℓ. The method is actually very simple: after each dynamic update, just multiply by the corresponding value of V:
(recorded as Formula 26)
Comparing Formula 26 with the pseudocode above, we can see that the formula in the pseudocode is just the matrix version of Formula 26. At this point, you can see that blockwise Self-Attention computation can indeed be achieved using Formula 26
Write the updated Oi back to HBM
Update ℓi and mi (i.e. write ℓi^new and mi^new back to HBM)
end for
end for
Return O.
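Putting the whole loop together, here is a minimal NumPy sketch of this tiled forward pass — no mask or dropout, illustrative block sizes, and of course not the actual CUDA kernel:

```python
import numpy as np

def flash_attention_forward(Q, K, V, Br=4, Bc=4):
    """Tiled attention forward pass with running row-max m and EXP-sum l,
    mirroring the pseudocode structure above (a sketch, not the real kernel)."""
    N, d = Q.shape
    O = np.zeros((N, d))           # output accumulator
    l = np.zeros(N)                # running softmax denominators
    m = np.full(N, -np.inf)        # running row maxima

    for j in range(0, N, Bc):                      # outer loop over K/V blocks
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br):                  # inner loop over Q blocks
            Sij = Q[i:i + Br] @ Kj.T               # block attention scores
            m_t = Sij.max(axis=1)                  # local row max
            P_t = np.exp(Sij - m_t[:, None])       # local exp scores
            l_t = P_t.sum(axis=1)                  # local EXP sum
            m_new = np.maximum(m[i:i + Br], m_t)
            l_new = (np.exp(m[i:i + Br] - m_new) * l[i:i + Br]
                     + np.exp(m_t - m_new) * l_t)
            # rescale the old accumulator, add this block's contribution,
            # then renormalize by the new denominator
            O[i:i + Br] = ((np.exp(m[i:i + Br] - m_new) * l[i:i + Br])[:, None]
                           * O[i:i + Br]
                           + np.exp(m_t - m_new)[:, None] * (P_t @ Vj)
                           ) / l_new[:, None]
            m[i:i + Br], l[i:i + Br] = m_new, l_new
    return O

# compare with the standard implementation that materializes the N x N matrix
rng = np.random.default_rng(0)
Q, K, V = rng.normal(size=(3, 8, 5))
S = Q @ K.T
P = np.exp(S - S.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
print(np.allclose(flash_attention_forward(Q, K, V), reference))  # → True
```

The sketch never builds the full N × N score matrix — only Br × Bc tiles exist at any moment — which is precisely the memory saving the pseudocode achieves on chip.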
2.2.2 Recalculation
As mentioned above, model training affects the effect of kernel fusion: in order to compute gradients in the backward pass, some intermediate results usually need to be written back to HBM during the forward computation, which generates additional HBM reads and writes and slows down the overall running time. Therefore, Flash Attention does not save the large intermediate result matrices for the backward pass
In the standard attention implementation, computing the gradients of Q, K and V in the backward pass requires the intermediate matrices S and P, but these two matrices are deliberately not saved. The trick here is recomputation: by saving the two statistics m and ℓ, attention can be quickly recomputed block by block on high-speed SRAM during the backward pass. Compared with the standard approach of reading a large intermediate attention matrix from HBM, this recomputation is much faster
In general, by adjusting the computation order of attention and introducing two additional statistics for blockwise computation, Flash Attention avoids materializing the complete attention matrix, reducing the memory complexity from O(N²) to O(N). In addition, whereas standard attention is limited by memory bandwidth, Flash Attention greatly reduces the number of HBM accesses through kernel fusion and blockwise computation. Although the recomputation in the backward pass adds extra FLOPs, the running time is still reduced thanks to the far fewer HBM accesses (a 7.6× speedup for GPT-2)
2.2.3 kernel fusion
To simplify the analysis, the mask and dropout operations were ignored when introducing attention above. The details of the Flash Attention forward pass are introduced in full below. Given the input Q, K, V ∈ R^(N×d), compute the attention output O:
Among them, τ is the softmax scaling factor, typically 1/√d. The MASK operation sets some elements of the input to −∞; after the softmax is computed they become 0, while the other elements remain unchanged
The main difference between the causal-lm structure and the prefix-lm structure lies precisely in their different MASK matrices. Dropout acts on each element point by point: with probability p_drop it sets the element to 0, and with probability 1 − p_drop it scales the element to x / (1 − p_drop)
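That point-wise dropout can be sketched as follows (inverted-dropout convention, as just described; the function name is mine):

```python
import numpy as np

def dropout(x, p_drop, rng):
    """Point-wise dropout as described above: each element is zeroed with
    probability p_drop, and survivors are scaled by 1 / (1 - p_drop) so the
    expected value of the output matches the input."""
    keep = rng.random(x.shape) >= p_drop
    return np.where(keep, x / (1.0 - p_drop), 0.0)

rng = np.random.default_rng(0)
y = dropout(np.ones((1000, 100)), p_drop=0.1, rng=rng)
print(round(y.mean(), 1))  # ≈ 1.0: the 1/(1 - p_drop) scaling preserves the mean
```

The scaling is what lets the same weights be used at inference time with dropout simply switched off.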
The tiling block computation allows us to use a single CUDA kernel to perform all the operations of attention: load the input data from HBM, perform all the computation steps (matrix multiplication, mask, softmax, dropout, matrix multiplication) in SRAM, and then write the result back to HBM. Fusing multiple operations into one kernel in this way avoids repeatedly reading and writing data from HBM
Kernel fusion is shown in the figure below (the picture comes from https://www.bilibili.com/video/BV1Zz4y1q7FX/)
Considering the mask and dropout operations, the forward calculation process of the complete Flash Attention algorithm is as follows:
// To be updated..
Part 3 FlashAttention2
// To be updated
References and Recommended Reading
- Transformer popular notes: gradually understand GPT and BERT from Word2Vec and Seq2Seq
- Analyze the parameter amount, calculation amount, intermediate activation, and KV cache of the transformer model
- FlashAttention: accelerates calculations, saves video memory, and provides IO-aware precise attention
- What is the speed optimization principle of FlashAttention? , where Civ, marsggbo All answers are good
- FlashAttention diagram (how to speed up Attention), FlashAttention algorithm detailed explanation
Creation and revision records
- 10.6, while explaining "Principle and Structure of FlashAttention: Reduce Memory Access and Improve Computational Speed" in the article "Deployment/Fine-tuning/Implementation of Two Generations of ChatGLM", I felt that article would keep getting longer, so I moved the FlashAttention-related content into this separate blog post
- 10.7, major revisions part 1
- 10.8, major revision to Section 2.2 of Part 2
- 10.9, Section 2.2 has been repeatedly revised to maximize readability
Section 2.2.1.1 has been repeatedly revised: comprehensively understanding blockwise attention computation (tiling) through 23 formulas
Repeatedly revised Section 2.2.1.3: the forward calculation algorithm of Flash Attention