A thorough understanding of FlashAttention and FlashAttention2: one of the technologies that allows the context length of large models to exceed 32K

Preface

There are two factors that made this article possible

  • The first factor is that when I led the LLM project team in Changsha on the paper-review GPT project, we ran into many engineering problems (the more LLM projects you do, the more you realize there is no secret in today's models: the choice of technical architecture and direction is no secret either; in the end it all comes down to continuously optimizing countless engineering details), such as data issues, and the problem of the context length of the large model itself
    The former has been solved. For details, please see the article "Source code interpretation and fine-tuning of academic paper GPT: from chatpaper and gpt_academic to the July paper review GPT", part 3
    But the latter is more troublesome, because there are more than 10,000 papers in the review corpus and each is basically longer than 10,000 words, while from the earlier articles in this blog we know that the context length of most models does not exceed 8K.
    | Model | Context length | Paper review performance (any length within 8K is not enough) |
    | --- | --- | --- |
    | GPT3.5 | 4K-16K (unified to 16K after November 7, 2023; the 16K fine-tuning interface of 3.5 was also opened on November 7, 2023) | The 16K effect is yet to be tested |
    | GPT4 | 8K-32K (upgraded to 128K on November 7) | To be tested |
    | LLaMA | 2048 | |
    | LLaMA2 | 4096 | |
    | LLaMA2-long (paper of September 27, 2023) | 32K | The effect is yet to be tested |
    | LongAlpaca-7B/13B/70B based on LongLoRA | 32K or more | The effect is yet to be tested |
    | Baichuan-7B/13B, Baichuan 2-7B/13B | 4096 | |
    | ChatGLM-6B | 2000 | |
    | ChatGLM2-6B | 8-32K | The 32K effect is yet to be determined |
  • The second factor is that this article was originally part of the ChatGLM2-6B content, compiled together with the content on the first-generation ChatGLM-6B. One of the more prominent features of ChatGLM2-6B is that it supports a 32K context, and ChatGLM2 implements that 32K context on the basis of FlashAttention
    Therefore, in order to explain the principles of FlashAttention, FlashAttention2, and related topics clearly, that earlier article kept getting longer, so I extracted the FlashAttention-related content into this standalone article

    As for LLaMA2-long and LongAlpaca-7B/13B/70B based on LongLoRA technology, they will be explained in the next blog post

This article, like other large model-related articles in this blog, pays great attention to readability.

  1. For example, in order to continuously improve readability, this article will be revised repeatedly in the near future, with careful attention to the structure, wording, and even the typesetting and punctuation of each heading. If something is not easy to understand, I would rather not write it.
  2. If you do not understand some content or some formula in a certain section, please feel free to leave a message in the comments of this article, and I will revise it promptly so that you can understand it (friendly reminder: this article assumes that you are already familiar with the transformer. If you are not, it is recommended to first read: Transformer popular notes: From Word2Vec and Seq2Seq, gradually understand GPT and BERT, especially the third part)

Part 1: Transformer’s space-time complexity and standard attention issues

FlashAttention is a new attention algorithm that is IO-aware, fast and memory-efficient, proposed by Stanford and the State University of New York in 2022. The corresponding paper is FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness, and this is its GitHub address

What kind of problem does it want to solve?

  1. First of all, the maximum input/output sequence length of large language models such as GPT3, LLaMA, ChatGLM, and BLOOM is only 2048 or 4096. What makes extending to longer sequences difficult? The essential reason is that both the computational complexity and the space complexity of the transformer model are O(N^2), where N is the sequence length
  2. To address this, FlashAttention proposes an exact attention that accelerates computation, saves GPU memory, and is IO-aware, which effectively alleviates the above problems.

Both the open-source LLaMA models launched by Meta and the open-source Falcon models launched by the United Arab Emirates use Flash Attention to accelerate computation and save GPU memory. Flash Attention has now been integrated into PyTorch 2.0, and open-source frameworks such as Triton and xFormers have also integrated implementations of it.
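As a quick illustration of that integration, below is a minimal usage sketch of PyTorch 2.x's fused scaled_dot_product_attention, which dispatches to a FlashAttention kernel when the shapes, dtypes and hardware allow it. The backend-selection context manager named here has moved between modules across PyTorch versions, so treat that part as an assumption to be checked against your installed version.

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch=2, heads=8, seq_len=1024, head_dim=64 (illustrative only)
q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)

# Fused attention: no N x N attention matrix is returned or materialized in HBM
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Optionally restrict dispatch to the FlashAttention backend
# (torch.backends.cuda.sdp_kernel in 2.0/2.1; later versions moved it to torch.nn.attention.sdpa_kernel)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```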

1.1 Transformer computational complexity——Self-Attention layer and MLP layer

To put it simply, the computational complexity is proportional to the square of the sequence length, N^2. Consider a small example: suppose the two matrices being multiplied have sizes (N \times d) and (d \times N). One way to compute the matrix product is to take the dot product of each row of the first matrix with each column of the second matrix.

Because we need to dot every row of the first matrix with every column of the second matrix, a total of N^2 dot products are required, and each dot product requires d multiplications, so the total complexity is \mathrm O(N^2d)

To be precise, when the input batch size is b and the sequence length is N,
the computation of an l-layer transformer model is l *\left(24 b N d^{2}+4 b N^{2} d\right), where d is the dimension of the word vectors, i.e. the hidden dimension (the hidden dimension is usually equal to the word-vector dimension)

But how is this result calculated step by step? Next, let’s break down this calculation process in detail.

1.1.1 Computational complexity of Self-Attention layer

First of all, we know that the transformer model consists of l identical layers, and each layer is divided into two parts: a self-attention block and an MLP block

The self-attention layer has two groups of parameters: one is the Q, K, V weight matrices W_Q, W_K, W_V and their biases; the other is the output weight matrix W_O and its bias. Its total computation works out to 8bNd^2 + 4bN^2d

How is it calculated specifically?

  1. The first step is to compute Q, K, V
    That is, Q=x W_{Q}, K=x W_{K}, V=x W_{V}
    The input and output shapes of this matrix multiplication are [b, N, d] \times[d, d] \rightarrow[b, N, d]
    The computation is: 3 * 2 b N d^{2}=6 b N d^{2}
  2. Compute Q K^T
    The input and output shapes of this part of the matrix multiplication are
    \left[b, h e a d \_n u m, N, p e r \_h e a d \_h i d d e n \_s i z e\right] \times \left[b, h e a d \_n u m, p e r \_h e a d \_h i d d e n \_s i z e, N\right] \rightarrow\left[b, h e a d \_n u m, N, N\right]
    Computation: 2bN^2d
  3. Compute the weighting of the scores over V, i.e. \text { score } \cdot V
    The input and output shapes of this part of the matrix multiplication are
    \left[b, h e a d \_n u m, N, N\right] \times\left[b, h e a d \_n u m, N, p e r \_h e a d \_h i d d e n \_s i z e\right] \rightarrow\left[b, h e a d \_n u m, N, p e r \_h e a d \_h i d d e n \_s i z e\right]
    Computation: 2bN^2d
  4. Linear projection after attention; the input and output shapes of the matrix multiplication are [b, N, d] \times[d, d] \rightarrow[b, N, d]
    The computation is 2bNd^2

    The final output result of the self-attention layer is
    x_{o u t}=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d}}\right) \cdot V \cdot W_{o}+x

1.1.2 Computational complexity of MLP layer

The MLP block consists of 2 linear layers, whose computation adds up to 16bNd^2

How is it calculated?

Generally, the first linear layer maps the dimension from d to 4d, and the second linear layer maps it from 4d back to d
x=f_{\text {gelu }}\left(x_{\text {out }} W_{1}\right) W_{2}+x_{\text {out }}

  1. The weight matrix of the first linear layer W_1 has shape [d,4d], which maps the dimension from d to 4d. The input and output shapes of the matrix multiplication are [b, N, d] \times[d, 4 d] \rightarrow[b, N, 4 d], and the computation is 8bNd^2
  2. The weight matrix of the second linear layer W_2 has shape [4d,d], which maps the dimension from 4d back to d. The input and output shapes of the matrix multiplication are [b, N, 4 d] \times[4 d, d] \rightarrow[b, N, d], and the computation is 8bNd^2

Adding up the computations above, the computation of each transformer layer is approximately 24 b N d^{2}+4 b N^{2} d

1.1.3 Computation of the logits: 2bNdV

In addition, another large piece of computation is the logits (after all, the word-embedding matrix also has a large number of parameters), which map the hidden vectors to the vocabulary size. To put it plainly, the word-vector dimension is usually equal to the hidden dimension d, so the word-embedding matrix has Vd parameters, and the weight matrix of the final output layer usually shares parameters with the word-embedding matrix (as Teacher Du Qiyue said, this is an important point of the transformer: parameter sharing reduces the parameter count; the word-embedding matrix is [vocab_size, hidden_size] and the output-layer matrix is [hidden_size, vocab_size], so they can be shared)
The input and output shapes of this matrix multiplication are [b, N, d] \times[d, V] \rightarrow[b, N, V], and the computation is 2bNdV

Therefore, for an l-layer transformer model, when the input data shape is [b,N], the computation of one training iteration is the sum of the above three parts, namely:
l *\left(24 b N d^{2}+4 b N^{2} d\right)+2 b N d V
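To make the bookkeeping above concrete, here is a small sketch that plugs numbers into l * (24bNd^2 + 4bN^2d) + 2bNdV; the GPT-2-like configuration at the bottom is only an illustrative assumption, not a figure from this article.

```python
def transformer_flops(l, b, N, d, V):
    """Approximate multiply-accumulate count of one forward pass, following the derivation above."""
    attention = 8 * b * N * d**2 + 4 * b * N**2 * d  # Section 1.1.1, per layer
    mlp = 16 * b * N * d**2                          # Section 1.1.2, per layer
    logits = 2 * b * N * d * V                       # Section 1.1.3, once at the end
    return l * (attention + mlp) + logits

# Illustrative GPT-2-small-like configuration (assumed numbers)
print(f"{transformer_flops(l=12, b=1, N=1024, d=768, V=50257):.3e}")
```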

1.2 Transformer’s space complexity——Self-Attention layer and MLP layer

The memory occupied by the intermediate activations is l *\left(34 b N d+5 b N^{2} a\right), where a is the number of attention heads

Large models usually use mixed-precision training, and the intermediate activations are generally stored as float16 or bfloat16. When analyzing the memory usage of the intermediate activations, it is assumed that they are saved in float16 or bfloat16 format, with each element occupying 2 bytes. The only exception is the mask matrix of the dropout operation, where each element occupies only 1 byte. In the following analysis, the unit is bytes, not the number of elements.

Each transformer layer contains a self-attention block and an MLP block, and each of the two has a corresponding layer normalization.

1.2.1 Intermediate activation of Self-Attention block

The calculation formula of the self-attention block is as follows:

Q=x W_{Q}, K=x W_{K}, V=x W_{V}
x_{o u t}=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d}}\right) \cdot V \cdot W_{o}+x

Finally, the intermediate activations of the self-attention block occupy 11 b N d+5 b N^{2} a bytes of memory

How is it calculated specifically?

  1. For Q, K, V, their common input x needs to be saved; this is an intermediate activation. The shape of x is [b, N, d], with bNd elements, occupying 2 * b N d=2 b N d bytes
  2. For the Q K^{T} matrix multiplication, the intermediate activations Q and K need to be saved. Both tensors have shape [b,N,d], occupying a total of 2 * 2 * b N d=4 b N d bytes
  3. For the \text { softmax () } function, its input Q K^{T} needs to be saved, occupying 2 b N^{2} a bytes, where a represents the number of attention heads
    \text { score }=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_{k}}}\right)

    where the shape of Q is \left[b, h e a d \_n u m, N, p e r \_h e a d \_h i d d e n \_s i z e\right],
    the shape of K^{T} is \left[b, h e a d \_n u m, p e r \_h e a d \_h i d d e n \_s i z e, N\right],
    so the shape of Q K^{T} is \left[b, h e a d \_n u m, N, N\right], with b N^{2} a elements, occupying 2 b N^{2} a bytes
  4. After the \text { softmax () } function is computed, a dropout operation is performed. A mask matrix needs to be saved; its shape is the same as that of Q K^{T}, occupying b N^{2} a bytes
  5. Computing the attention over V, i.e. \text { score } \cdot V, requires saving \text { score }, with a size of 2 b N^{2} a, and V, with a size of 2 b N d. The two together occupy 2 b N^{2} a+2 b N d bytes
  6. Finally there are the output projection and a dropout operation. The output projection needs to save its input, of size 2 b N d; dropout needs to save its mask matrix, of size \text { bNd }. The two together occupy 3 b N d bytes

Therefore, adding up the above intermediate activations, the intermediate activations of the self-attention block occupy 11 b N d+5 b N^{2} a bytes of memory

1.2.2 Intermediate activation of MLP block

The calculation formula of the MLP block is as follows: x=f_{\text {gelu }}\left(x_{\text {out }} W_{1}\right) W_{2}+x_{\text {out }}. Finally, for the MLP block, the intermediate activation value that needs to be saved is 19 b N d

How is it calculated specifically?

  1. The first linear layer needs to save its input, occupying 2 b N d bytes
  2. The activation function needs to save its input, occupying 8 b N d bytes
  3. The second linear layer needs to save its input, occupying 8 b N d bytes
  4. Finally, there is a dropout operation, which needs to save its mask matrix, occupying \text { bNd } bytes

1.2.3 The intermediate activations of the two layer norms

In addition, the self-attention block and the MLP block each have a corresponding layer normalization. Each layer norm needs to save its input, of size 2bNd, so the 2 layer norms need 4bNd bytes of intermediate activations

To sum up, the intermediate activations that each transformer layer needs to save occupy 34 b N d+5 b N^{2} a bytes of memory

For the l-layer transformer model, there are also an embedding layer and the final output layer. The embedding layer does not require intermediate activations. In general, when the hidden dimension d is relatively large and the number of layers l is deep, this part of the intermediate activations is very small and can be ignored

Therefore, for the l-layer transformer model, the memory occupied by the intermediate activations can be approximated as \left(34 b N d+5 b N^{2} a\right) * l. For more analysis, see the article "Analysis of parameters, calculations, intermediate activations, and KV cache of the transformer model"
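Analogously, here is a tiny sketch of the activation-memory estimate (34bNd + 5bN^2a) * l in bytes; the 7B-like configuration is again just an assumed example.

```python
def activation_bytes(l, b, N, d, a):
    """Approximate intermediate-activation memory (bytes) under mixed-precision training,
    per the derivation above: (34*b*N*d + 5*b*N^2*a) per layer."""
    return l * (34 * b * N * d + 5 * b * N**2 * a)

# Illustrative 7B-like configuration (assumed): 32 layers, d=4096, 32 heads, 4K context
gib = activation_bytes(l=32, b=1, N=4096, d=4096, a=32) / 2**30
print(f"~{gib:.1f} GiB of intermediate activations per sample")
```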

From the content of the above two sections, we can see that the computation and memory of the transformer model grow quadratically with the sequence length N, which limits the maximum sequence length N of large language models

Second, GPT4 has extended the maximum sequence length N to 32K, and Claude has extended the maximum sequence length N to 100K. These efforts must have adopted some optimization methods to reduce the complexity of the native transformer. How, specifically?
We know that each transformer layer is divided into two parts, the self-attention block and the MLP block, but the 4bN^2d term in the computation above and the 5bN^2a term in the intermediate activations are both produced by the self-attention block and have nothing to do with the MLP block

1.3 Two problems with Standard Attention: high memory usage and high number of HBM reads and writes

  1. To review, the computation of the attention mechanism in the transformer is as follows (again, if you have forgotten the details of the transformer, it is recommended to first read: Transformer popular notes; if you forget what softmax is, review this article: How to understand Word2Vec in a popular way):

    \operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V
    where Q, K, V \in R^{N \times d}, N is the sequence length, d is the dimension of each attention head, and the output can be written as O \in R^{N \times d}
  2. The above formula can be broken down into the following three steps:

    S=Q K^{\top} \in R^{N \times N}

    P=\operatorname{softmax}(S) \in R^{N \times N}

    O=P V \in R^{N \times d}

    In the standard attention implementation, S, P \in R^{N \times N} must be written back to HBM (this HBM will be explained shortly below), occupying O\left(N^{2}\right) memory; usually N \gg d
    For example, for GPT2, N = 1024 and d = 64; for GPT3, N = 2048 and d = 128. In short, the O\left(N^{2}\right) memory required by the attention matrices P, S is much larger than the O(N d) memory required by Q, K, V, O
  3. The figure below shows the implementation process of standard attention (a minimal code sketch of it is also given after this list)

    There are a total of eight HBM matrix read and write operations. These eight operations are:
    the first line reads \mathrm {Q,K} (two reads) and writes \mathrm{S} once, three accesses in total
    the second line reads \mathrm{S} once and writes \mathrm{P} once, two accesses in total
    the third line reads \mathrm {P,V} (two reads) and writes \mathrm{O} once, three accesses in total
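The three lines S = QK^T, P = softmax(S), O = PV map directly onto a naive implementation. Below is a minimal PyTorch sketch of standard attention; every intermediate (S and P) is materialized as a full N x N tensor, which is exactly the O(N^2) memory and HBM traffic discussed above.

```python
import math
import torch

def standard_attention(Q, K, V):
    """Naive attention that materializes the full N x N matrices S and P.
    Q, K, V: [batch, heads, N, d]."""
    d = Q.shape[-1]
    S = Q @ K.transpose(-2, -1) / math.sqrt(d)  # read Q, K from HBM; write S (N x N)
    P = torch.softmax(S, dim=-1)                # read S; write P (N x N)
    O = P @ V                                   # read P, V; write O (N x d)
    return O
```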

Add some background knowledge

  1. There have been many approximate-attention approaches that attempt to reduce the computation and memory requirements of attention. For example, sparse-approximation and low-rank-approximation methods reduce the computational complexity to linear or sub-linear in the sequence length.
  2. However, these approximate-attention methods have not been widely adopted, because they focus too much on reducing FLOPs (the number of floating-point operations) and ignore the memory-access overhead of IO reads and writes, so they do not effectively reduce the running time (wall-clock time).
  3. In short, on modern GPUs the computing speed has far outpaced the memory-access speed. The bottleneck of most computing operations in the transformer is GPU memory access. For memory-bound operations, IO awareness is very important, because memory reads and writes take up most of the running time

The memory of the GPU is composed of multiple memories of different sizes and different read and write speeds. The smaller the memory, the faster the reading and writing speed. For A100-40GB, the memory classification chart is as follows

  • SRAM is distributed across the 108 streaming multiprocessors, 192 KB per SM, for a total of 192 * 108 K B=20,736 K B \approx 20 M B
    It is very fast to compute on, but small
  • High Bandwidth Memory (HBM) is what we usually call GPU memory, 40GB in size. The read/write bandwidth of SRAM is 19TB/s, while that of HBM is only 1.5TB/s, less than 1/10 of SRAM
    It is large, but slow to access

In short, the computational complexity and space complexity of the self-attention block, the core component of the transformer, are quadratic in the sequence length N. Moreover, within the self-attention block, apart from the two large matrix multiplications (Q K^{\top} and P V), which are compute-bound, the rest are memory-bound point-wise operations, such as the mask operation on S, the softmax operation on S, and the dropout operation on P. The performance of these point-wise operations is limited by memory bandwidth, which slows down the running time. That is, there are two problems in the standard attention implementation:

  1. It takes up a lot of GPU memory: the process instantiates the complete attention matrices P, S \in R^{N \times N}, resulting in O\left(N^{2}\right) memory requirements
  2. It has many HBM reads and writes, which slows down the running time (wall-clock time)

The following Memory-efficient Attention in Section 2.1 and Flash Attention in Section 2.2 are to solve the above two problems respectively.

Part 2 Forward pass of FlashAttention: Memory-efficient Attention/Flash Attention

2.1 Memory-efficient Attention: Reduce the memory complexity from square to linear, but the number of HBM accesses is still square

In the attention computation, the main challenge for saving memory is that the softmax is coupled with the columns of K and V. The approach is to compute the normalization factor of the softmax separately, in order to achieve decoupling

  1. In order to simplify the analysis, ignore the "subtract the maximum value" step when computing softmax.
    Denote the i-th query vector (the i-th row of Q) as q_{i} \in R^{d} and the j-th key vector as k_{j} \in R^{d}; then S_{i j}=q_{i}^{\top} k_{j} \in R. Define the normalization factor of softmax as:

    L_{i}=\sum_{j} e^{q_{i}^{\top} k_{j}} \in R
  2. Denote v_{j} \in R^{d} as the j-th value vector of V; then the i-th output vector o_i of O is:
    o_{i}=P_{i:} V=\sum_{j} P_{i j} v_{j}=\sum_{j} \frac{e^{q_{i}^{\top} k_{j}}}{L_{i}} v_{j}
  3. After the normalization factor L_i has been computed, o_i can be obtained by repeatedly accumulating \frac{e^{q_{i}^{\top} k_{j}}}{L_{i}} v_{j}

In this way, the memory-efficient attention mechanism changes the order of computation. Compared with Standard Attention, it reduces the memory complexity from O(N^2) to O(N)

This method was used in "Online normalizer calculation for softmax" and "Self-attention Does Not Need O\left(n^{2}\right) Memory", and is called "lazy softmax". It avoids instantiating the complete attention matrices S,P, thereby saving GPU memory. However, the number of HBM accesses is still O(N^2), so the running time is not reduced
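A minimal sketch of this "lazy softmax" idea is shown below (ignoring the max-subtraction step, exactly as in the simplified derivation above): each output row o_i is accumulated key block by key block, so no N x N matrix is ever stored. The chunk size is an arbitrary illustrative choice; note that this version still streams over all keys for every query, so its memory is O(N) but its HBM traffic remains O(N^2).

```python
import torch

def memory_efficient_attention(Q, K, V, chunk=128):
    """Row-by-row attention with O(N) extra memory. Q, K, V: [N, d].
    No max-subtraction, matching the simplified derivation above."""
    N, d = Q.shape
    O = torch.zeros_like(Q)
    for i in range(N):
        numerator = torch.zeros(d, dtype=Q.dtype)
        L_i = torch.zeros((), dtype=Q.dtype)
        for start in range(0, N, chunk):       # stream over key/value blocks
            k_blk = K[start:start + chunk]      # [c, d]
            v_blk = V[start:start + chunk]      # [c, d]
            e = torch.exp(k_blk @ Q[i])         # [c], the terms e^{q_i^T k_j}
            numerator += e @ v_blk              # accumulate sum_j e_j * v_j
            L_i += e.sum()                      # accumulate the normalizer L_i
        O[i] = numerator / L_i
    return O
```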

2.2 Flash Attention: Reduce the number of HBM reads and writes through kernel fusion, avoiding frequent reads and writes of data from HBM

As mentioned above

  1. In the standard attention implementation, the performance of attention is mainly limited by memory bandwidth, i.e. it is memory-bound. Frequently reading and writing the N \times N matrices from HBM is the main bottleneck affecting performance
  2. Although approximate-attention methods such as sparse approximation and low-rank approximation reduce the computational FLOPs, for memory-bound operations the running-time bottleneck is the time spent reading and writing data from HBM, so reducing the amount of computation does not effectively reduce the running time (wall-clock time)
  3. For memory-bound standard attention, Flash Attention is IO-aware, and its goal is to avoid frequently reading and writing data from HBM

Therefore, it is very important to reduce the number of reads and writes to HBM and make effective use of the much faster SRAM for computation. For operations whose performance is limited by memory bandwidth, a common acceleration method is kernel fusion, which typically consists of three steps:

  1.  Each kernel loads input data from the low-speed HBM into the high-speed SRAM
  2. In SRAM, calculations are performed
  3. After the calculation is completed, the calculation results are written from SRAM to HBM.

In this way, instead of repeatedly executing "read input data from HBM, perform the computation in SRAM, and finally write the result back to HBM" for every single operation, multiple operations are merged into one, reducing the number of HBM reads and writes (it should be noted that model training usually weakens the effect of operator fusion, because in order to compute gradients in the backward pass, certain intermediate results usually need to be written to HBM)

Some students may not understand the above explanation. In fact, the principle is very simple, that is, the following two sentences

  1. If data is written from SRAM back to HBM only to be (re)loaded in order to compute the softmax,
  2. then it can instead be kept in SRAM, all the intermediate steps performed there, and only the final result written back to HBM

The former is shown on the left side of the figure below, and the latter on the right side (figure source)

2.2.1 A comprehensive explanation of computing attention in blocks (tiling)——kernel fusion must fit within the SRAM, but the SRAM is too small

Although kernel fusion merges multiple operations into one and uses the fast SRAM for computation, thereby reducing the number of HBM reads and writes and effectively reducing the running time of memory-bound operations, there is a problem

  1. The memory of SRAM is limited, and the complete attention cannot be computed at once. Therefore, the computation must be performed in blocks, so that the memory required by each block does not exceed the size of SRAM.
    In other words: memory is limited --> reduce the number of HBM reads and writes --> kernel fusion --> must fit within the SRAM --> block-wise computation, so the block size block_size cannot be too large, otherwise it will cause OOM
  2. What is the difficulty of block-wise computation?
    The attention computation is "matrix multiplication --> scale --> mask --> softmax --> dropout --> matrix multiplication". Block-wise computation of the matrix multiplications and of the point-wise operations (scale, mask, dropout) is easy to implement; the difficulty lies in the block-wise computation of softmax. Since computing the normalization factor (denominator) of softmax requires the complete input data, it is harder to compute in blocks

How should we understand the sentence above, "since computing the normalization factor (denominator) of softmax requires the complete input data, it is harder to compute in blocks"?

Let’s first review the calculation formula of softmax.

  1. Considering the vector \left[x_{1}, x_{2}, \cdots, x_{d}\right]​, the calculation process of native softmax is as follows:
    \operatorname{softmax}\left(x_{i}\right)=\frac{e^{x_{i}}}{\sum_{j=1}^{d} e^{x_{j}}}
  2. In actual hardware, the representable range of floating-point numbers is limited
    For float32 and bfloat16, when x = 89, e^x becomes very large or even becomes inf, causing the problem of data overflow
    Therefore, in order to avoid numerical overflow and ensure numerical stability, the maximum value is usually subtracted before exponentiation; this is called "safe softmax", and all deep learning frameworks now compute softmax this way (a small numerical sketch of the difference follows this list)

    m(x) is defined as the maximum value in \left[x_{1}, x_{2}, \cdots, x_{d}\right]
    m(x)=\max \left(\left[x_{1}, x_{2}, \ldots, x_{d}\right]\right)

    \quad \operatorname{softmax}\left(x_{i}\right)=\frac{e^{x_{i}-m(x)}}{\sum_{j=1}^{d} e^{x_{j}-m(x)}}
  3. When training a language model, the cross-entropy loss function is usually used. The cross-entropy loss is equivalent to first applying the log_softmax function and then computing the negative log-likelihood
    When computing log_softmax, "subtracting the maximum" is also applied, which not only avoids numerical overflow and improves numerical stability, but also speeds up the computation
    \log \left(\operatorname{softmax}\left(x_{i}\right)\right)=\log \left(\frac{e^{x_{i}-m}}{\sum_{j=1}^{d} e^{x_{j}-m}}\right)=x_{i}-m-\log \left(\sum_{j=1}^{d} e^{x_{j}-m}\right)
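Here is the small numerical sketch promised above, contrasting the naive and the "safe" softmax; the input values are simply chosen large enough to overflow float32.

```python
import torch

x = torch.tensor([10.0, 100.0, 1000.0])  # scores chosen so that exp() overflows float32

naive = torch.exp(x) / torch.exp(x).sum()               # exp(1000) -> inf
safe = torch.exp(x - x.max()) / torch.exp(x - x.max()).sum()

print(naive)  # contains nan because of inf / inf
print(safe)   # a valid probability distribution, approximately [0., 0., 1.]
```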

In summary, to compute how much attention a particular i-th token in the input sequence pays to the other tokens in the sequence, all of these scores (the x_j above) need to be readily available in SRAM. But the capacity of SRAM is limited, and N (the sequence length) can be 1,000 or even 100,000 tokens, so N^2 explodes

In short, the main idea of tiling is to compute attention in blocks. The difficulty of block-wise computation lies in the block-wise computation of softmax, because the softmax is coupled with the columns of K. By introducing two additional statistics m(x),l(x) for decoupling (the former is the running maximum of the scores, the latter is the running sum of the exp scores), block-wise computation becomes possible

2.2.1.1 Comprehensive understanding of block calculation attention tiling through 23 formulas

Let's start from the beginning and sort it out comprehensively (the following 23 formulas are explained based on this source)

  1. S=Q K^{\top} \in R^{N \times N}

  2. P=\operatorname{softmax}(S) \in R^{N \times N}

  3. O=P V \in R^{N \times d}

  4. Considering the vector \left[x_{1}, x_{2}, \cdots, x_{d}\right]​, the calculation process of the native softmax is as follows:
    \operatorname{softmax}\left(x_{i}\right)=\frac{e^{x_{i}}}{\sum_{j=1}^{d} e^{x_{j}}}
    where the numerator e^{x_{i}} is the exponential of the i-th element of the vector x, and the denominator \sum_{j=1}^{d} e^{x_{j}} is the sum of the exponentials of all elements of x, which ensures that the output of the softmax function is a probability distribution, i.e. that all elements sum to 1

  5. m(x)=\max \left(\left[x_{1}, x_{2}, \ldots, x_{d}\right]\right)
     m(x) is defined as the maximum value among \left[x_{1}, x_{2}, \cdots, x_{d}\right]
  6. \quad f(x) =\left[\begin{array}{lll} e^{x_{1}-m(x)} & \ldots & e^{x_{d}-m(x)} \end{array}\right]
    f(x) is a new vector in which each term corresponds to the numerator of the standard softmax in Formula 4, that is, for each term e^{x_i}, the maximum m(x) of \left[x_{1}, x_{2}, \cdots, x_{d}\right] is subtracted from its exponent x_i
  7. \quad \ell(x) = \sum_{i} f(x)_{i}
    l(x) is the summation term in the denominator of softmax. For the convenience of the later description, the summation term in Formula 7 will be called the "EXP summation term"
  8. \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}
    Consider a vector x \in \mathbb R^{2d} of size 2d, and divide it into two blocks: x=[x^{(1)},x^{(2)}]
    where x^{(1)},x^{(2)} \in \mathbb R^d
    In other words, the sub-vector x^{(1)} is the first half of the original vector x, and the sub-vector x^{(2)} is the second half of the original vector x

    Assume that in the block-wise computation, x^{(1)} is processed first and then x^{(2)}
    Then, for the sub-vector x^{(1)}, Formulas 5 to 8 are used to compute its "local" \mathrm {softmax}; the computation process is shown in Formulas 9-12
  9. m\left(x^{(1)}\right)=\max \left(\left[x_{1}^{(1)}, x_{2}^{(1)}, \ldots, x_{d}^{(1)}\right]\right)
  10. f\left(x^{(1)}\right)=\left[e^{x_{1}^{(1)}-m\left(x^{(1)}\right)}, \ldots, e^{x_{d}^{(1)}-m\left(x^{(1)}\right)}\right]
  11. l\left(x^{(1)}\right)=\sum_{i} f\left(x^{(1)}\right)_{i}
  12. \operatorname{softmax}\left(x^{(1)}\right)=\frac{f\left(x^{(1)}\right)}{l\left(x^{(1)}\right)}
    Obviously, the \operatorname{softmax}\left(x^{(1)}\right) obtained so far cannot be regarded as the final result for the sub-vector x^{(1)}. The reason is very simple:
    First, the maximum subtracted in the exponent in Formula 10 should be the maximum of the entire vector x, not m\left(x^{(1)}\right), the maximum of the sub-vector x^{(1)} alone
    Second, the EXP summation term in the denominator of Formula 12 should be the summation over the entire vector x, not just the sum over the elements of the sub-vector x^{(1)}
    Because the \mathrm {softmax} (x^{(1)}) obtained by the above calculation is not the final result, it is called the "local" softmax

    Next we introduce how to save a few additional variables and update the "local" \mathrm {softmax} to the "global" one as x^{(2)} is processed
    First, after processing the sub-vector x^{(1)}, save m(x^{(1)}) and l(x^{(1)}); saving just these two scalars is much cheaper than saving the entire sub-vector x^{(1)}
    Second, two global scalars need to be saved: m_{max} and l_{all}
    m_{max} represents the current global maximum; because only x^{(1)} has been processed so far, temporarily m_{max}=m(x^{(1)})
    l_{all} represents the global EXP summation term; because only x^{(1)} has been processed so far, temporarily l_{all}=l(x^{(1)})
    Then, processing x^{(2)} in a similar way to x^{(1)}, the following results can be obtained:
  13. m\left(x^{(2)}\right)=\max \left(\left[x_{1}^{(2)}, x_{2}^{(2)}, \ldots, x_{d}^{(2)}\right]\right)

  14. f\left(x^{(2)}\right)=\left[e^{x_{1}^{(2)}-m\left(x^{(2)}\right)}, \ldots, e^{x_{d}^{(2)}-m\left(x^{(2)}\right)}\right]

  15. l\left(x^{(2)}\right)=\sum_{i} f\left(x^{(2)}\right)_{i}

  16. \operatorname{softmax}\left(x^{(2)}\right)=\frac{f\left(x^{(2)}\right)}{l\left(x^{(2)}\right)}
    In the same way, the softmax obtained by Formula 16 is also local rather than global.
    But after processing x^{(2)}, the information m\left(x^{(2)}\right) and l\left(x^{(2)}\right) of x^{(2)} can be used to update the two previously saved global scalars m_{max} (=m(x^{(1)})) and l_{all} (=l(x^{(1)})), as shown in Formulas 17 and 18 below:

  17. m_{m a x}^{n e w}=\max \left(\left[m_{\max }, m\left(x^{(2)}\right)\right]\right)
    The meaning of Formula 17 is very simple: the updated global maximum is the larger of the previous maximum m_{max} and the maximum m\left(x^{(2)}\right) of x^{(2)}

  18. l_{\text {all }}^{n e w}=e^{m_{\max }-m_{\max }^{\text {new }}} l_{\text {all }}+e^{m_{x^{(2)}}-m_{\max }^{n e w}} l\left(x^{(2)}\right)
    Formula 18 is the method for updating the global EXP summation term. Wait a minute, how did this come about? Shouldn't it be l_{\text {all }}^{n e w}=l_{\text {all }}+l\left(x^{(2)}\right)?

    Taking l(x^{(2)}) as an example: we say l(x^{(2)}) is "local" because, so far, it has only used the information of x^{(2)}; it needs m^{new}_{max} to be updated to "global"

    Expanding the formula for l(x^{(2)}) in Formula 15 slightly, i.e. l\left(x^{(2)}\right)=\sum_{i} f\left(x^{(2)}\right)_{i}, we can get:

  19. l\left(x^{(2)}\right)=\sum_{i} e^{x_{i}^{(2)}-m\left(x^{(2)}\right)}
    It can be seen that the reason why l\left(x^{(2)}\right) is "local" rather than "global" is that the max value it subtracts is "local"; we only need to replace this max value with the global one
    To this end, l\left(x^{(2)}\right) can be transformed as follows to become global

  20. i.e.
    \begin{aligned} l^{\text {new }}\left(x^{(2)}\right) & =l\left(x^{(2)}\right) \cdot e^{m\left(x^{(2)}\right)-m_{\text {max }}^{\text {new }}} \\ & =\sum_{i} e^{x_{i}^{(2)}-m\left(x^{(2)}\right)} \cdot e^{m\left(x^{(2)}\right)-m_{m a x}^{\text {new }}} \\ & =\sum_{i} e^{x_{i}^{(2)}-m_{\text {max }}^{\text {new }}} \end{aligned}
    At this point, l(x^{(2)}) has been updated to the "global" l^{new}(x^{(2)})
    This formula shows that when you need to update a "local" l to "global", you just multiply it by one term: e^{m - m^{new}_{max}}, where m is the current maximum corresponding to l and m^{new}_{max} is the current global maximum
    Returning to Formula 18, we can see that it applies exactly this global-update method to l_{all} and l(x^{(2)}) respectively, and then sums them to obtain the EXP summation term that is global up to the current point

    Since the current numerator and denominator of \operatorname{softmax}\left(x^{(2)}\right) are both local, they both need to be updated to the global level. Let's look at the numerator part first
    According to Formula 16, \operatorname{softmax}\left(x^{(2)}\right)=\frac{f\left(x^{(2)}\right)}{l\left(x^{(2)}\right)}, that is, {f\left(x^{(2)}\right)} = \operatorname{softmax}\left(x^{(2)}\right) \times {l\left(x^{(2)}\right)}, so once f(x^{(2)}) is updated the softmax value can be updated directly as well
    f(x^{(2)}) is defined by Formula 14, that is, f\left(x^{(2)}\right)=\left[e^{x_{1}^{(2)}-m\left(x^{(2)}\right)}, \ldots, e^{x_{d}^{(2)}-m\left(x^{(2)}\right)}\right], and it can be updated as follows:
  21. i.e.
    \begin{aligned} f^{n e w}\left(x^{(2)}\right) & =f\left(x^{(2)}\right) \cdot e^{m\left(x^{(2)}\right)-m_{m a x}^{n e w}} \\ & =\left[e^{x_{1}^{(2)}-m\left(x^{(2)}\right)}, \ldots, e^{x_{d}^{(2)}-m\left(x^{(2)}\right)}\right] \cdot e^{m\left(x^{(2)}\right)-m_{m a x}^{\text {new }}} \\ & =\left[e^{x_{1}^{(2)}-m_{m a x}^{n e w}}, \ldots, e^{x_{d}^{(2)}-m_{m a x}^{n e w}}\right] \end{aligned}
    Comparing f(x^{(2)}) before and after the transformation once again confirms the conclusion from Formula 20: to turn f(x^{(2)}) from a local value into a global value, just multiply it by one term: e^{m - m^{new}_{max}}, where m is the current maximum corresponding to f and m^{new}_{max} is the current global maximum

    Now let's look at the denominator part: to update the denominator, we actually only need to replace l(x^{(2)}) with l^{new}_{all}. This is done by the following formula:

  22. \frac{\operatorname{softmax}\left(x^{(2)}\right) \cdot l\left(x^{(2)}\right)}{l_{\text {all }}^{\text {new }}}
    where l_{\text {all }}^{\text {new }} is calculated from Formula 18

    Okay, here comes a question
    Question 1: Many readers have expressed doubts about this: why is the denominator in Formula 22 l_{\text {all }}^{\text {new }} rather than l^{\text {new }}\left(x^{(2)}\right)?

    Answer: The reason is very simple. Consider why we use softmax at all: it assigns each element of a vector a probability between 0 and 1 such that these probabilities sum to 1
    When we say "global", we want to assign a probability to every element of the entire data, not just to a subset of it
    So when a data stream is split into two parts x^{(1)} and x^{(2)}: you first see x^{(1)} and compute its softmax, and then you see x^{(2)}. To compute the softmax of the entire data stream (x^{(1)} and x^{(2)} combined), you cannot consider only x^{(2)}; you must consider the global effect after x^{(1)} and x^{(2)} are merged

    Question 2: The next question may quickly follow: l^{new}(x^{(2)}) in Formula 20 is just the global version of l(x^{(2)}), and it still only considers x^{(2)} alone
    Question 3: In other words, Formula 20 and Formula 19 look almost the same; what is the difference between the two?

    Answer:
    Formula 19: l\left(x^{(2)}\right)=\sum_{i} e^{x_{i}^{(2)}-m\left(x^{(2)}\right)}. The maximum here is m(x^{(2)}), i.e. a local maximum. This means that, for this data block, each element is compared with the maximum within the block itself
    Formula 20: l^{\text {new }}\left(x^{(2)}\right)=\sum_{i} e^{x_{i}^{(2)}-m_{\max }^{\text {new }}}. The maximum here is m^{new}_{max}, the global maximum of x^{(1)} and x^{(2)}. This means that each element of x^{(2)} is compared with the maximum of all elements observed so far
    So the main difference is the reference maximum they use: Formula 19 uses a local maximum, while Formula 20 uses the global maximum. This transformation is for numerical stability, ensuring that we do not run into numerical overflow when computing the exponentials of e

    Finally, combining Formula 21 and Formula 22, the update of \operatorname{softmax}\left(x^{(2)}\right) can be achieved as follows:

  23. \operatorname{softmax}^{(n e w)}\left(x^{(2)}\right)=\frac{\operatorname{softmax}\left(x^{(2)}\right) \cdot l\left(x^{(2)}\right) \cdot e^{m\left(x^{(2)}\right)-m_{\text {max }}^{\text {new }}}}{l_{\text {all }}^{\text {new }}}
    Look carefully at Formula 23. When we update the \mathrm {softmax} value of x^{(2)}, we only use the additionally saved quantities mentioned earlier:
    the local \mathrm {softmax} value \mathrm {softmax} (x^{(2)}) of x^{(2)}, from Formula 16
    the local EXP summation term l(x^{(2)}) of x^{(2)}, from Formula 15
    the local maximum m(x^{(2)}) of x^{(2)}, from Formula 13
    the global maximum m^{new}_{max}, from Formula 17
    the global EXP summation term l^{new}_{all}, from Formula 18

    The whole update process never needs to use the vector values of x^{(1)}. In the same way, replacing the first three items above with the values for x^{(1)}, the \mathrm {softmax} of x^{(1)} can be corrected and updated as well. This is the essence of how FlashAttention dynamically updates \mathrm {softmax}

The above is actually an incremental calculation process.

  1. We first calculate the local softmax value of a block and then store it
  2. When the next block is processed, the old softmax value can be updated based on the new global maximum value and global EXP summation term at this time, and then the next block is processed, and then updated
  3. After all blocks are processed, the softmax values ​​of all blocks at this time are "global"
2.2.1.2 A brief summary of calculating attention tiling in blocks

Your CPU may be overheating by now. To ease the burn, let's finally summarize the above process with a simple example.

For two vectors x^{(1)}, x^{(2)} \in R^{d}, the softmax of the concatenated vector x=\left[x^{(1)}, x^{(2)}\right] \in R^{2 d} can be computed in a decoupled way:

m(x)=m\left(\left[x^{(1)} x^{(2)}\right]\right)=\max \left(m\left(x^{(1)}\right), m\left(x^{(2)}\right)\right)

\quad f(x)=\left[e^{m\left(x^{(1)}\right)-m(x)} f\left(x^{(1)}\right) \quad e^{m\left(x^{(2)}\right)-m(x)} f\left(x^{(2)}\right)\right]

\ell(x)=\ell\left(\left[x^{(1)} x^{(2)}\right]\right)=e^{m\left(x^{(1)}\right)-m(x)} \ell\left(x^{(1)}\right)+e^{m\left(x^{(2)}\right)-m(x)} \ell\left(x^{(2)}\right)

\quad \operatorname{softmax}(x)=\frac{f(x)}{\ell(x)}

By maintaining two additional statistics m(x),l(x), softmax can be computed in blocks. It should be noted that GPU multithreading can be used to compute the softmax of multiple blocks in parallel at the same time; in order to make full use of the hardware, the computation over the blocks is not serial but parallel

I seem to see a hint of confusion on your face. It's okay, don't worry, July understands: bare formulas are admittedly rather obscure. Let's use an example to vividly illustrate how to compute softmax in blocks.

Calculate softmax for vector [1,2,3,4] and divide it into two blocks [1,2] and [3,4] for calculation

Calculate block 1:

\begin{array}{l} m_{1}=\max ([1,2])=2\\ \begin{array}{c} f_{1}=\left[e^{1-2}, e^{2-2}\right]=\left[e^{-1}, e^{0}\right] \\ l_{1}=\sum f_{1}=e^{-1}+e^{0} \\ o_{1}=\frac{f_{1}}{l_{1}}=\frac{\left[e^{-1}, e^{0}\right]}{e^{-1}+e^{0}} \end{array} \end{array}

Calculate block 2:

\begin{array}{l} m_{2}=\max ([3,4])=4\\ \begin{array}{c} f_{2}=\left[e^{3-4}, e^{4-4}\right]=\left[e^{-1}, e^{0}\right] \\ l_{2}=\sum f_{2}=e^{-1}+e^{0} \\ o_{2}=\frac{f_{2}}{l_{2}}=\frac{\left[e^{-1}, e^{0}\right]}{e^{-1}+e^{0}} \end{array} \end{array}

Combine to get the complete softmax result:

\begin{array}{l} m=\max \left(m_{1}, m_{2}\right)=4\\ f=\left[e^{m_{1}-m} f_{1}, e^{m_{2}-m} f_{2}\right]=\left[e^{-3}, e^{-2}, e^{-1}, e^{0}\right]\\ l=e^{m_{1}-m} l_{1}+e^{m_{2}-m} l_{2}=e^{-3}+e^{-2}+e^{-1}+e^{0}\\ o=\frac{f}{l}=\frac{\left[e^{-3}, e^{-2}, e^{-1}, e^{0}\right]}{e^{-3}+e^{-2}+e^{-1}+e^{0}} \end{array}
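This two-block example can be checked directly in code. Below is a minimal sketch that reproduces the merge rules above on the vector [1, 2, 3, 4] and compares the result against a one-shot softmax.

```python
import torch

def block_stats(x):
    """Local statistics of one block: max m, exp vector f, exp sum l."""
    m = x.max()
    f = torch.exp(x - m)
    return m, f, f.sum()

x1, x2 = torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])
m1, f1, l1 = block_stats(x1)
m2, f2, l2 = block_stats(x2)

# Merge the two blocks with the update rules above
m = torch.maximum(m1, m2)
f = torch.cat([torch.exp(m1 - m) * f1, torch.exp(m2 - m) * f2])
l = torch.exp(m1 - m) * l1 + torch.exp(m2 - m) * l2
blockwise = f / l

reference = torch.softmax(torch.tensor([1.0, 2.0, 3.0, 4.0]), dim=0)
print(torch.allclose(blockwise, reference))  # True
```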

2.2.1.3 Forward calculation algorithm of Flash Attention algorithm

Simplifying the analysis while ignoring mask and dropout, the forward calculation process of the Flash Attention algorithm is as follows

As can be seen from the figure above, the algorithm runs the outer loop over the K,V dimension and the inner loop over the Q dimension (while in the Triton code implementation, the outer loop is over the Q dimension and the inner loop is over the K,V dimension)

To be thorough, I will explain the 16 lines of pseudocode above line by line. To facilitate everyone's understanding, I quote a flow chart drawn by marsggbo on Zhihu; you can refer to this flow chart to improve your understanding of the corresponding code (a runnable end-to-end sketch is also given after this 16-line walkthrough).

First, there are basic conditions:

\text { Matrices } \mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d} \text { in HBM, on-chip SRAM of size } M
where N is the sequence length, d is the dimension of each attention head, and M is the size of the SRAM

  1. Set block sizes B_{c}=\left\lceil\frac{M}{4 d}\right\rceil, B_{r}=\min \left(\left\lceil\frac{M}{4 d}\right\rceil, d\right)
    Compute the row/column block sizes. Why \lceil M / 4d \rceil? Because the query, key, and value vectors are d-dimensional and we also need to combine them into the d-dimensional output vector, this size basically lets us fill the SRAM to capacity with the q, k, v, and o vectors

    Take GPT2 and A100 as an example:
    The SRAM size of A100 is M=192KB=196608B
    For GPT2, N=1024 and d=64, so Q, K, V each have dimension N \times d=1024 \times 64, and the intermediate results S, P have dimension N \times N=1024 \times 1024

    B_{c}=\lceil 196608 / 4 / 64\rceil=768 ; \quad B_{r}=\min (768,64)=64


  2. Initialize the output matrix O with all 0s; it will act as an accumulator
    l is similar to the l(x) above; its purpose is to save the cumulative denominator of softmax, i.e. the sum of the exp scores
    m is similar to the m(x) above; it saves the row-wise maximum score and is initialized to -inf, because we are going to take a running max over it, and whatever the max of the first block is, it will certainly be greater than -inf


  3. Divide Q, K and V into blocks
    Specifically, Q is divided along the row direction into T_r blocks, each of size B_{r} \times d
    K and V are divided along the row direction into T_c blocks, each of size B_{c} \times d
    With the numbers above: T_{c}=\lceil 1024 / 768\rceil=2 ; \quad T_{r}=\lceil 1024 / 64\rceil=16


  4. Split O, l, m into blocks
    O is split in the same way as Q: into T_r blocks along the row direction, each of size B_{r} \times d
    The vectors l and m are each split into T_r blocks, and each sub-vector has size B_r

    Combining steps 3 and 4 above, the relationship between the blocks is as follows

  5.  for 1 ≤ j ≤ Tc do
    start the outer loop across columns (controlled by T_c, moving from one column block to the next), i.e. across the key/value vectors, traversing K and V for a total of T_{c}=2 iterations

  6.      Load Kj , Vj from slow HBM to on-chip fast SRAM.
         Load the K_j and V_j blocks from HBM into SRAM (their size is B_{c} \times d=768 \times d). At this point in time we still have about 50% of the SRAM unoccupied (reserved for Q and O)
         

  7.        for 1 ≤ i ≤ Tr do
          start the inner loop across rows (from one row block to the next), i.e. across the query vectors, for a total of T_{r}=16 iterations, traversing Q, O, l, m

  8.             Load Qi, Oi, ℓi, mi from HBM to on-chip SRAM.
                Load the Q_i block (B_r \times d = 64 \times d) and the O_i block (B_r \times d = 64 \times d), as well as l_i (size B_r) and m_i (size B_r), into SRAM

                Here you need to ensure that l_i and m_i can be loaded into SRAM (including all intermediate variables)

  9.             On chip, compute \mathbf{S}_{i j}=\mathbf{Q}_{i} \mathbf{K}_{j}^{T} \in \mathbb{R}^{B_{r} \times B_{c}}, i.e. C_{64 \times 768}=A_{64 \times d} \times B_{d \times 768}

                This step multiplies Q_i (B_r \times d) by the transpose of K_j (d \times B_c) to obtain the block Attention Score \mathbf{S}_{i j}=\mathbf{Q}_{i} \mathbf{K}_{j}^{T} \in \mathbb{R}^{B_{r} \times B_{c}}. The Attention Score obtained in the standard Transformer computation is an N \times N matrix, as shown in the figure below (in the figure N=12, B_r = 3, B_c =2)
                 
                 The standard Transformer needs to compute the Attention Score over the entire matrix (gray), while the Attention Score computed for each block corresponds to the blue and orange areas in the figure

                 For another example, assume that the outer loop index is j (j=3), the inner loop index is i (i=2), N is 25, and the block size is 5. The following is the result just calculated (assuming a 1-based index) :

                 
  10.             On chip, compute\tilde{m}_{i j}=\operatorname{rowmax}\left(\mathbf{S}_{i j}\right) \in \mathbb{R}^{B_{r}}, \tilde{\mathbf{P}}_{i j}=\exp \left(\mathbf{S}_{i j}-\tilde{m}_{i j}\right) \in \mathbb{R}^{B_{r} \times B_{c}} \text { (pointwise) }\tilde{\ell}_{i j}= \operatorname{rowsum}\left(\tilde{\mathbf{P}}_{i j}\right) \in \mathbb{R}^{B_{r}}

                Use the scores calculated in the previous step to calculate \tilde{m}_{i j}, \tilde{\ell}_{i j} and \tilde{\mathbf{P}}_{i j}
                That is, on the scores S_{ij}, compute the maximum value in each row: \tilde m_{ij} = \mathrm {rowmax}(\mathrm{S}_{ij}) \in \mathbb R^{B_r}

                Based on \hat m_{ij}, calculate the exponential term (normalized - take the row maximum and subtract it from the row score, then EXP): \hat P_{ij} = \mathrm{exp}(\mathrm{S}_{ij} - \hat m_{ij})\in \mathbb R^{B_r\times B_c}

                Then based on \hat P_{ij}, calculate the EXP summation term (row-wise sum of matrix P): \hat l_{ij} = \mathrm {rowsum} (\hat P_{ij}) \in \mathbb R^{B_r}

  11.             On chip, compute m_{i}^{\text {new }}=\max \left(m_{i}, \tilde{m}_{i j}\right) \in \mathbb{R}^{B_{r}}, \ell_{i}^{\text {new }}=e^{m_{i}-m_{i}^{\text {new }}} \ell_{i}+e^{\tilde{m}_{i j}-m_{i}^{\text {new }}} \tilde{\ell}_{i j} \in \mathbb{R}^{B_{r}}
                This step computes m_{i}^{\text {new }} and \ell_{i}^{\text {new }}; we can again reuse the chart above:

                m_{i} contains the row-wise maximum of all previous blocks (j=1 and j=2, indicated in green), and \tilde{m}_{i j} contains the row-wise maximum of the current block (indicated in yellow). In order to get m_{i}^{\text {new }} we just need to take the maximum of \tilde{m}_{i j} and m_{i}; \ell_{i}^{\text {new }} is computed similarly

                As above, this uses Formulas 17 and 18 to update m_i and l_i respectively; the meanings are the same

  12.             On chip, compute \mathbf{O}_{i} \leftarrow \operatorname{diag}\left(\ell_{i}^{\text {new }}\right)^{-1}\left(\operatorname{diag}\left(\ell_{i}\right) e^{m_{i}-m_{i}^{\text {new }}} \mathbf{O}_{i}+e^{\tilde{m}_{i j}-m_{i}^{\text {new }}} \tilde{\mathbf{P}}_{i j} \mathbf{V}_{j}\right), and write \mathbf{O}_{i} \text { to HBM }

                 To better understand the formula in this line, first note that the purpose of computing multiple rows together is batch computation: in the figure, each small block S_{ij} has multiple rows (3 rows in the figure), but there is no interaction between the data of different rows; it is just a batching strategy. The real meaning of blocking lies in the columns, because softmax is performed along the column direction

                 So for the convenience of understanding, it can be imagined that B_r is equal to 1, that is, only one block of size (1 \times B_c) in the above figure is calculated each time

                 Based on this simplification, let's look at the entire softmax update process. We use S_i to represent the Attention Score of each row, and SM_i to represent the \mathrm {softmax} of each row

                 

                 Since batch computation is not being considered now, the Attention Score processed each time is a vector, such as S_{11} in the figure above. We first use Formulas 5 to 8 to compute its local \mathrm {softmax},

                 which gives SM_1. At this time, only the first two positions in SM_1 have values, corresponding to the local \mathrm {softmax} values of S_{11}

                 We then do the same with each row below it (the first two columns of the green part)

                 Then S_{12} is processed: first use Formulas 5 to 8 to compute its local \mathrm {softmax}, and then use Formula 23 to update SM_1 (note that, as can be seen from line 11 above, \ell_{i}^{\text {new }} plays the role of l_{\text {all }}^{\text {new }}):

                 \mathrm{SM}_1^{(new)} = \frac{\mathrm{SM}_1 \cdot l_1 \cdot e^{m_1 -m^{new}_{1}}}{l^{new}_ {1}} + \frac{\hat P_{12} \cdot e^{m_{12} -m^{new}_{1}}}{l^{new}_{1}}                                     (recorded as Formula 24)

                 where \hat P_{12} corresponds to the f(x) of Formula 6, i.e. \quad f(x) =\left[\begin{array}{lll} e^{x_{1}-m(x)} & \ldots & e^{x_{d}-m(x)} \end{array}\right]

                 When S_{13} is processed, continue to apply formula 24 to update:

                 \mathrm{SM}_1^{(new)} = \frac{\mathrm{SM}_1 \cdot l_1 \cdot e^{m_1 -m^{new}_{1}}}{l^{new}_ {1}} + \frac{\hat P_{13} \cdot e^{m_{13} -m^{new}_{1}}}{l^{new}_{1}}                                     (recorded as Formula 25)

                 Let's go one step further and try to update the output O_1 directly, not just the \mathrm {softmax} value \mathrm {SM}_1. The method is actually very simple: after each dynamic update of the \mathrm {softmax}, just multiply by the corresponding block of V:

                 {O}_1^{(new)} = \frac{\mathrm {O}_1 \cdot l_1 \cdot e^{m_1 -m^{new}_{1}}}{l^{new}_{1 }} + \frac{\hat P_{12} \cdot e^{m_{12} -m^{new}_{1}}}{l^{new}_{1}} \cdot V_2                                  (recorded as Formula 26)

                 where V_2 is the block of V corresponding to S_{12}

                 Comparing Formula 26 with the pseudocode above, we can see that the formula in the pseudocode is just the matrix version of Formula 26. At this point, you can see that block Self-Attention calculation can be achieved using Equation 26

  13.              Write \ell_{i} \leftarrow \ell_{i}^{\text {new }}, m_{i} \leftarrow m_{i}^{\text {new }} \text { to HBM }
                 Update l_i and m_i

  14.        end for 

  15.   end for

  16.   Return O.
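Putting the 16 lines together, here is the runnable end-to-end sketch promised above: a single-head, pure-PyTorch rendition of the forward pass (no mask or dropout), written for clarity rather than speed. The block sizes Br and Bc are illustrative assumptions rather than the SRAM-derived values, and a Python loop of course gains none of the real kernel's IO savings; it only demonstrates that the tiled update reproduces exact attention.

```python
import math
import torch

def flash_attention_forward(Q, K, V, Br=64, Bc=64):
    """Single-head FlashAttention-style forward pass (tiling only). Q, K, V: [N, d]."""
    N, d = Q.shape
    scale = 1.0 / math.sqrt(d)
    O = torch.zeros(N, d)
    l = torch.zeros(N)                    # running exp-sum per row
    m = torch.full((N,), float("-inf"))   # running max per row

    for j in range(0, N, Bc):             # outer loop over K/V blocks (line 5)
        Kj, Vj = K[j:j + Bc], V[j:j + Bc]
        for i in range(0, N, Br):         # inner loop over Q/O/l/m blocks (line 7)
            Qi = Q[i:i + Br]
            Sij = (Qi @ Kj.T) * scale                     # line 9
            m_tilde = Sij.max(dim=-1).values              # line 10: rowmax
            P_tilde = torch.exp(Sij - m_tilde[:, None])   # line 10: exp
            l_tilde = P_tilde.sum(dim=-1)                 # line 10: rowsum

            m_new = torch.maximum(m[i:i + Br], m_tilde)   # line 11
            l_new = (torch.exp(m[i:i + Br] - m_new) * l[i:i + Br]
                     + torch.exp(m_tilde - m_new) * l_tilde)

            # line 12: rescale the old output and add the new block's contribution
            O[i:i + Br] = (torch.exp(m[i:i + Br] - m_new)[:, None] * l[i:i + Br, None] * O[i:i + Br]
                           + torch.exp(m_tilde - m_new)[:, None] * (P_tilde @ Vj)) / l_new[:, None]
            m[i:i + Br], l[i:i + Br] = m_new, l_new       # line 13
    return O

# Quick check against standard attention (sizes are arbitrary)
Q, K, V = (torch.randn(256, 64) for _ in range(3))
reference = torch.softmax((Q @ K.T) / math.sqrt(64), dim=-1) @ V
print(torch.allclose(flash_attention_forward(Q, K, V), reference, atol=1e-5))
```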

2.2.2 Recalculation

As mentioned above, model training weakens the effect of kernel fusion: in order to compute gradients through the backward pass, some intermediate results usually need to be written back to HBM during the forward computation, which generates additional HBM reads and writes and slows down the running time. Therefore, Flash Attention does not save the large intermediate result matrices for the backward pass

In the standard attention implementation, computing the gradients of Q,K,V in the backward pass requires the N \times N intermediate matrices S,P, but these two matrices are not saved. The trick here is recomputation: by saving the two statistics m(x),l(x), Attention can be quickly recomputed on the fast SRAM during the backward pass, reconstructing the attention matrices S,P block by block. Compared with the standard attention approach of reading the large intermediate attention matrix from HBM, recomputation is much faster
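A minimal sketch of the recomputation idea: given the saved per-row statistics m and l (and Q, K, which live in HBM anyway), any block of P needed by the backward pass can be rebuilt on the fly instead of being stored. The function below is only an illustration of that trick; names and shapes are assumptions, not the library's actual backward kernel.

```python
import math
import torch

def recompute_P_block(Qi, Kj, m_i, l_i):
    """Rebuild one block of the softmax matrix P from saved statistics.
    Qi: [Br, d]; Kj: [Bc, d]; m_i, l_i: [Br] row-wise max and exp-sum saved in the forward pass."""
    d = Qi.shape[-1]
    Sij = (Qi @ Kj.T) / math.sqrt(d)                     # recompute the scores in SRAM
    return torch.exp(Sij - m_i[:, None]) / l_i[:, None]  # exact P block, never written to HBM
```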

In general, Flash Attention avoids instantiating the complete N \times N attention matrices S,P by adjusting the order of the attention computation and introducing two additional statistics for block-wise computation, reducing the memory complexity from O\left(N^{2}\right) to O\left(N\right). In addition, for memory-bound standard attention, Flash Attention greatly reduces the number of HBM accesses through kernel fusion and block-wise computation. Although the recomputation in the backward pass adds extra FLOPs, the reduced HBM traffic still makes the overall computation faster (7.6x for GPT2)

2.2.3 Kernel fusion

To simplify the analysis, the mask and dropout operations were ignored when introducing attention above. Below, the details of the Flash Attention forward pass are introduced in full. Given the input Q, K, V \in R^{N \times d}, compute the attention output O \in R^{N \times d}

\begin{array}{c} S=\tau Q K^{\top} \in R^{N \times N} \\ S^{\text {masked }}=M A S K(S) \in R^{N \times N} \\ P=\operatorname{softmax}\left(S^{\text {masked }}\right) \in R^{N \times N} \\ P^{\text {dropped }}=\operatorname{dropout}\left(P, p_{d r o p}\right) \in R^{N \times N} \\ O=P^{\text {dropped }} V \in R^{N \times d} \end{array}

where \tau is the scaling factor of the softmax, typically \frac{1}{\sqrt{d_{k}}}. The MASK operation sets some elements of the input to −∞, which become 0 after the softmax is computed, while the other elements remain unchanged

The main difference between the causal-lm structure and the prefix-lm structure is that their MASK matrices are different. \text { dropout }(x, p) acts on each element of x point by point: with probability p it sets the element to 0, and with probability 1-p it sets the element to \frac{x}{1-p}

Tiling (block-wise computation) allows us to use a single CUDA kernel to perform all the operations of attention: load the input data from HBM, perform all the computation steps (matrix multiplication, mask, softmax, dropout, matrix multiplication) in SRAM, and then write the result back to HBM. Fusing multiple operations into one through kernel fusion avoids repeatedly reading and writing data from HBM

Kernel fusion is shown in the figure below (image from https://www.bilibili.com/video/BV1Zz4y1q7FX/)

Considering the mask and dropout operations, the forward calculation process of the complete Flash Attention algorithm is as follows:

// To be updated..


Part 3 FlashAttention2

// To be updated

References and Recommended Reading

  1. Transformer popular notes: gradually understand GPT and BERT from Word2Vec and Seq2Seq
  2. Analyze the parameter amount, calculation amount, intermediate activation, and KV cache of the transformer model
  3. FlashAttention: accelerates calculations, saves video memory, and provides IO-aware precise attention
  4. What is the speed optimization principle of FlashAttention?, where the answers by Civ and marsggbo are both good
  5. FlashAttention diagram (how to speed up Attention), FlashAttention algorithm detailed explanation

Creation and revision records

  1. 10.6, when explaining "the principle and structure of FlashAttention: reduce memory access and improve computational speed" in the article "Deployment/Fine-tuning/Implementation of Two Generations of ChatGLM", I felt that part would keep getting longer, so I moved the FlashAttention-related content into this separate blog post
  2. 10.7, major revisions part 1
  3. 10.8, major revision to Section 2.2 of Part 2
  4. 10.9, Section 2.2 has been repeatedly revised to maximize readability
    Repeatedly revised Section 2.2.1.1: Comprehensive understanding of block calculation attention tiling through 23 formulas
    Repeatedly revised Section 2.2.1.3: Forward calculation algorithm of Flash Attention algorithm

Origin blog.csdn.net/v_JULY_v/article/details/133619540