FlashAttention

FlashAttention gets a major upgrade: Stanford's Tri Dao has rewritten the algorithm from scratch, and the second generation runs up to 9 times faster than standard attention, making much longer Transformer context lengths practical.

Following the runaway popularity of FlashAttention, the fast and memory-efficient attention algorithm, the upgraded version 2 is here.

FlashAttention-2 is an algorithm written from scratch to speed up attention and reduce its memory footprint, without any approximations.

Compared with the first generation, FlashAttention-2 is twice as fast.

Compared to PyTorch's standard attention, it can run up to 9 times faster.

A year ago, Tri Dao of the Stanford AI Lab released FlashAttention, which made attention 2 to 4 times faster. Since then, FlashAttention has been adopted by many companies and research labs and is widely used in most LLM libraries.

With new use cases such as querying long documents and writing long-form stories, the context of large language models has grown much longer than before: GPT-4 has a context length of 32k, MosaicML's MPT has 65k, and Anthropic's Claude has 100k.

However, expanding the context length of the Transformer is a major challenge, because the runtime and memory requirements of the attention layer at its core are quadratic in the length of the input sequence.

Tri Dao has been working on FlashAttention-2, which is 2 times faster than v1 and 5 to 9 times faster than standard attention, and it has reached a training speed of 225 TFLOPs/s on the A100.

Paper address: https://tridao.me/publications/flash2/flash2.pdf

Project address: https://github.com/Dao-AILab/flash-attention
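For readers who just want to try it, the repository ships as a Python package. The snippet below is only a rough usage sketch; the function name (`flash_attn_func`), tensor layout, and dtype requirements are stated from memory and should be checked against the project's README.

```python
# Rough usage sketch -- check the flash-attention README for the authoritative
# API.  The function name, tensor layout (batch, seqlen, nheads, headdim), and
# fp16-on-CUDA requirement below are assumptions, not guaranteed by this post.
import torch
from flash_attn import flash_attn_func  # installed via `pip install flash-attn`

batch, seqlen, nheads, headdim = 2, 4096, 16, 64
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attn_func(q, k, v, causal=True)  # output has the same shape as q
```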

FlashAttention-2: Better Algorithms, Parallelism, and Work Partitioning

End-to-end GPT model training at up to 225 TFLOPs/s

Although FlashAttention was already 2 to 4 times faster than optimized baselines at the time of its release, there was still considerable room for improvement.

For example, FlashAttention is still not as fast as optimized matrix multiplication (GEMM) operations, reaching only 25-40% of the theoretical maximum FLOPs/s (for example, 124 TFLOPs/s on an A100 GPU).

[Figure: how GEMMs are used for convolution]

Over the past few months, the researchers have been developing FlashAttention-2, which improves on the first generation across the board.

The researchers say the second generation is essentially a complete rewrite from scratch, built on Nvidia's CUTLASS 3.x and its core library CuTe. In terms of speed, FlashAttention-2 is about 2 times faster than the previous version, reaching up to 230 TFLOPs/s on the A100 GPU.

When training a GPT-style language model end to end, the researchers achieved a training speed of up to 225 TFLOPs/s (72% model FLOPs utilization).

Reordering attention calculations

FlashAttention is an algorithm that reorders the attention computation, using tiling and recomputation to significantly speed it up and to reduce memory usage from quadratic to linear in the sequence length. It loads blocks of the input from HBM (GPU memory) into SRAM (fast on-chip cache), performs attention on each block, and updates the output in HBM.

Since the large intermediate attention matrix is never written to HBM, the amount of memory read and written is reduced, yielding a 2-4x speedup in execution time.

The figure below shows the FlashAttention forward pass: with tiling and softmax rescaling, the computation proceeds block by block, avoiding most reads and writes to HBM while producing the correct output with no approximation. However, FlashAttention still suffers from some inefficiency due to suboptimal work partitioning between thread blocks and between warps on the GPU, which leads to low occupancy or unnecessary shared memory reads and writes.
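To make the tiling and online-softmax rescaling idea concrete, here is a minimal, purely illustrative PyTorch sketch of the computation (a single head, no masking, and none of the CUDA-level details of the actual kernel):

```python
# Purely illustrative sketch of tiling + online-softmax rescaling; the real
# kernel works block-by-block in SRAM with details that are omitted here.
import torch

def tiled_attention(q, k, v, block=128):
    """q, k, v: (seqlen, head_dim). K/V are processed block by block while
    keeping a running max m, running normalizer l, and unnormalized output acc."""
    seqlen, d = q.shape
    scale = d ** -0.5
    m = torch.full((seqlen, 1), float("-inf"))
    l = torch.zeros(seqlen, 1)
    acc = torch.zeros(seqlen, d)
    for start in range(0, seqlen, block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = (q @ kb.T) * scale                     # scores against this block
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        p = torch.exp(s - m_new)                   # block-local softmax numerator
        corr = torch.exp(m - m_new)                # rescale earlier partial results
        l = l * corr + p.sum(dim=-1, keepdim=True)
        acc = acc * corr + p @ vb
        m = m_new
    return acc / l                                 # normalize once at the end

q, k, v = (torch.randn(512, 64) for _ in range(3))
ref = torch.softmax((q @ k.T) * 64 ** -0.5, dim=-1) @ v
assert torch.allclose(tiled_attention(q, k, v), ref, atol=1e-4)
```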

Fewer non-matmul FLOPs (non-matrix-multiply floating-point operations)

The researchers reduced the number of non-matmul FLOPs by adjusting the FlashAttention algorithm. This matters because modern GPUs have dedicated compute units (such as Tensor Cores on Nvidia GPUs) that make matmul much faster than other operations.

For example, the maximum theoretical throughput of FP16/BF16 matmul on the A100 GPU is 312 TFLOPs/s, while the theoretical throughput of non-matmul FP32 operations is only 19.5 TFLOPs/s.

Put differently, each non-matmul FLOP is 16 times more expensive than a matmul FLOP (312 / 19.5 = 16).

So to keep throughput high, researchers want to spend as much time as possible on matmul FLOPs.

The researchers also rewrote the online softmax trick used in FlashAttention to reduce the number of rescaling operations, as well as bounds checking and causal masking operations, without changing the output.
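As a toy illustration of the kind of rescaling that can be deferred (simplified here to a single query row, an assumption made purely for readability), both variants below compute the same softmax-weighted sum: the eager one renormalizes the accumulator after every block, while the deferred one divides by the normalizer only once at the end.

```python
# Toy illustration, not the actual kernel: deferring normalization removes
# per-block elementwise (non-matmul) work without changing the output.
import torch

def softmax_weighted_sum(scores, values, eager):
    m = torch.tensor(float("-inf"))
    l = torch.tensor(0.0)
    out = torch.zeros(values.shape[-1])
    for s_blk, v_blk in zip(scores.split(4), values.split(4)):
        m_new = torch.maximum(m, s_blk.max())
        corr = torch.exp(m - m_new)
        p = torch.exp(s_blk - m_new)
        l_new = l * corr + p.sum()
        if eager:                      # keep `out` normalized at every step
            out = out * (l * corr / l_new) + (p / l_new) @ v_blk
        else:                          # keep `out` unnormalized
            out = out * corr + p @ v_blk
        m, l = m_new, l_new
    return out if eager else out / l   # deferred: one division at the end

scores, values = torch.randn(16), torch.randn(16, 8)
assert torch.allclose(softmax_weighted_sum(scores, values, True),
                      softmax_weighted_sum(scores, values, False), atol=1e-5)
```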

Better parallelism

FlashAttention v1 parallelizes over the batch size and the number of heads: one thread block processes one attention head, so there are (batch_size * number of heads) thread blocks in total. In the forward pass, each worker (thread block) is responsible for a block of rows of the attention matrix; in the backward pass, each worker processes a block of columns of the attention matrix.

Each thread block runs on a streaming multiprocessor (SM); the A100 GPU, for example, has 108 of them. This scheduling is efficient when the number of thread blocks is large (say, ≥ 80), because then almost all of the compute resources on the GPU are put to use.

For long sequences (which usually imply smaller batch sizes or fewer heads), the researchers additionally parallelize over the sequence-length dimension to make better use of the GPU's multiprocessors, which yields a significant speedup in this regime.
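A quick back-of-the-envelope sketch (the block size of 128 query rows is an assumption for illustration) shows why this matters in the long-sequence, small-batch regime:

```python
# The number of independent thread blocks should comfortably exceed the number
# of SMs (108 on an A100) to keep the GPU busy.
import math

def thread_blocks(batch, heads, seqlen, block_rows=128, parallel_over_seq=False):
    blocks = batch * heads
    if parallel_over_seq:
        blocks *= math.ceil(seqlen / block_rows)
    return blocks

# Long-sequence, small-batch regime: batch=1, 16 heads, 32k tokens.
print(thread_blocks(1, 16, 32768))                          # 16   -> most SMs idle
print(thread_blocks(1, 16, 32768, parallel_over_seq=True))  # 4096 -> plenty of work
```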

Better work partitioning

Even within each thread block, the researchers have to decide how to divide the work among different warps (a group of 32 threads that execute together). Typically, 4 or 8 warps are used per thread block, and the partitioning scheme is shown in the figure below.

The researchers improved this partitioning in FlashAttention-2, reducing the synchronization and communication between warps and therefore the shared memory reads and writes. In FlashAttention (v1), for each block, K and V are split across 4 warps while Q is kept accessible to all warps; this is known as the "sliced-K" scheme.

However, this is inefficient, because all warps need to write their intermediate results to shared memory, synchronize, and then add those results together.

These shared memory reads and writes slow down the forward pass in FlashAttention.

In FlashAttention-2, Q is instead split across 4 warps while K and V are kept accessible to all warps.

After each warp performs a matrix multiplication to obtain its slice of QK^T, it simply multiplies by the shared V to obtain its slice of the output.

This eliminates the need for communication between warps. The reduction in shared memory reads and writes increases speed.
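The contrast between the two partitions can be mimicked on the CPU. The NumPy toy below (softmax omitted, and the "warps" are just array slices, so this is only an analogy to the actual GPU code) shows that splitting K/V requires a final sum across partial results, while splitting Q lets each slice own disjoint output rows:

```python
# Toy NumPy illustration of the two warp partitions, for one block of work.
import numpy as np

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 32)) for _ in range(3))
reference = (q @ k.T) @ v
n_warps = 4

# v1-style "sliced-K": each "warp" sees all of Q but only a slice of K/V.
partials = [(q @ k_w.T) @ v_w                       # partial result per warp
            for k_w, v_w in zip(np.split(k, n_warps), np.split(v, n_warps))]
out_sliced_k = sum(partials)                        # reduction across warps needed

# v2-style "split-Q": each "warp" owns a slice of Q and sees all of K/V.
out_split_q = np.concatenate([(q_w @ k.T) @ v       # independent output rows
                              for q_w in np.split(q, n_warps)])

assert np.allclose(out_sliced_k, reference) and np.allclose(out_split_q, reference)
```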

New features: head dimensions up to 256, multi-query attention

FlashAttention (v1) only supports head dimensions up to 128. Although this covers most models, some are still left out.

FlashAttention-2 now supports head dimensions up to 256, which means that models such as GPT-J, CodeGen, CodeGen2, and Stable Diffusion 1.x can use FlashAttention-2 for acceleration and memory savings.

v2 also supports multi-query attention (MQA) and grouped-query attention (GQA). GQA shares a single key and value head across each group of query heads, interpolating between multi-head and multi-query attention.

These are attention variants in which multiple query heads attend to the same key and value head, reducing the size of the KV cache during inference and significantly improving inference throughput.
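A minimal sketch of grouped-query attention in plain PyTorch (the (batch, heads, seqlen, dim) layout is an assumption, and this is not the FlashAttention kernel) shows how each group of query heads shares one K/V head; MQA is the special case with a single K/V head.

```python
# Minimal grouped-query attention sketch: K/V heads are shared across groups
# of query heads, which shrinks the KV cache during inference.
import torch

def grouped_query_attention(q, k, v):
    """q: (b, hq, s, d); k, v: (b, hkv, s, d) with hq divisible by hkv."""
    b, hq, s, d = q.shape
    group = hq // k.shape[1]
    # Repeat each K/V head for the query heads in its group.
    k = k.repeat_interleave(group, dim=1)
    v = v.repeat_interleave(group, dim=1)
    attn = torch.softmax(q @ k.transpose(-2, -1) / d ** 0.5, dim=-1)
    return attn @ v

q = torch.randn(1, 8, 128, 64)   # 8 query heads
k = torch.randn(1, 2, 128, 64)   # 2 shared K/V heads -> groups of 4
v = torch.randn(1, 2, 128, 64)
print(grouped_query_attention(q, k, v).shape)  # torch.Size([1, 8, 128, 64])
```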

Attention benchmarks

The researchers measured the runtime of different attention methods on an A100 80GB SXM4 GPU under various settings (with and without causal mask, head dimension 64 or 128). They found that FlashAttention-2 is about 2 times faster than the first generation (including other implementations in the xformers library and in Triton), and up to 9 times faster than a standard attention implementation in PyTorch.

[Figure: forward + backward speed on the A100 GPU]

Simply by running the same implementation on an H100 GPU (without using special instructions to take advantage of new hardware features such as TMA and fourth-generation Tensor Cores), the researchers were able to reach speeds of up to 335 TFLOPs/s.

[Figure: forward + backward speed on the H100 GPU]

When used for end-to-end training of GPT-style models, FlashAttention-2 reaches speeds of up to 225 TFLOPs/s on the A100 GPU (72% model FLOPs utilization).
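For reference, the quoted 72% model FLOPs utilization follows directly from the A100's 312 TFLOPs/s FP16/BF16 matmul peak mentioned earlier:

```python
# Model FLOPs utilization (MFU) = achieved training throughput / hardware peak.
peak_tflops, achieved_tflops = 312, 225
print(f"MFU = {achieved_tflops / peak_tflops:.0%}")  # MFU = 72%
```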

Compared with a baseline that already uses the well-optimized FlashAttention (v1), the end-to-end training speed improves by a further 1.3 times.

Future work

Being 2 times faster means that researchers can train models with 16k context length for the same cost as an 8k-context model before. Such models can understand long books and reports, high-resolution images, audio, and video.

At the same time, FlashAttention-2 will also accelerate the training, fine-tuning and inference of existing models.

In the near future, the researchers also plan to collaborate more broadly to make FlashAttention widely applicable to different types of devices (e.g., H100 GPUs, AMD GPUs) as well as new data types (e.g., FP8).

Next, the researchers plan to further optimize FlashAttention-2 for the H100 GPU to use new hardware features (TMA, 4th generation Tensor Core, fp8, etc.).

Combining the low-level optimizations in FlashAttention-2 with high-level algorithmic changes (such as local, dilated, block-sparse attention) allows researchers to train AI models with longer contexts.

The researchers are also excited to work with compiler researchers to make these optimization techniques easier to program.

References:

https://princeton-nlp.github.io/flash-atttention-2/

 
