この記事の目的はフラッシュ アテンションについて詳しく説明することですが、なぜフラッシュ アテンションについて説明するのでしょうか? FlashAttention はアテンションの計算を並べ替えるアルゴリズムであるため、近似を行わずにアテンションの計算を高速化し、メモリ フットプリントを削減できます。したがって、現在の LLM モデルの高速化としては非常に優れたソリューションです。この記事では古典的な V1 バージョンを紹介し、最新の V2 にはその他の最適化が施されていますが、ここでは当面は紹介しません。FlashNote の V1 バージョンは 5 ~ 10 倍高速であると主張されているため、それがどのように実装されているかを調べてみましょう。
導入
論文のタイトルは次のとおりです。
「FlashAttendant: IO 認識を備えた高速でメモリ効率の高い Exact Attendance」
メモリ効率 通常のアテンション (シーケンスの長さは 2 次、O(N²)) と比較して、FlashAttention は二次/線形 N (O(N)) です。また、これは注意メカニズムの近似ではありません (たとえば、スパースまたは低ランクの行列近似法)。その出力は「従来の」注意メカニズムと同じです。通常の注意と比較して、FlashAttendant の注意は「知覚」されます。
これは、基盤となるハードウェアのメモリ階層の知識を活用します (GPU などですが、他の AI アクセラレータも機能するはずです。ここでは例として GPU を使用しています)。一部の[近似]方法では、計算要件をシーケンス長において線形またはほぼ線形に削減しますが、その多くはメモリ アクセス (IO) のオーバーヘッドを無視して FLOP を削減することに重点を置いています。
長年の開発を経て、GPU の FLOPS はメモリ スループット (TB/秒) よりも速く成長しています。メモリのボトルネックは真剣に考慮する必要があります。FLOPS とメモリ スループットは密接に組み合わせる必要があり、ハードウェアのギャップがあるため、ソフトウェア レベルで作業のバランスをとる必要があります。
演算とメモリアクセスの比率に応じて、演算は次の 2 種類に分類されます。
- 計算上の制約: 行列の乗算
- メモリ制約: 要素操作 (アクティブ化、ドロップアウト、マスキング)、マージ操作 (ソフトマックス、レイヤーノルム、合計など)
現在の AI アクセラレータ (GPU) では、メモリ サイズによって制限されます。「主に要素ごとの演算で構成されている」ため、より正確に言えば、注意の算術密度はそれほど高くありません。
この写真を見てみましょう:
ご覧のとおり、マスキング、ソフトマックス、およびドロップアウトは時間のかかる演算であり、行列の乗算ではありません (FLOPS のほとんどが matmul にある場合でも)。メモリは単一の成果物ではなく、本質的に階層構造になっており、一般的なルールとして、メモリが高速であるほど高価になり、容量が小さくなります。
FlashAttend の注意が「認識されている」と上で述べたことは、要約すると、HBM (高帯域幅メモリ) よりもはるかに高速な SRAM を利用して、2 つの間の通信を少なくすることです。
A100 を例に挙げます。
A100 GPU には、帯域幅 1.5 ~ 2.0 TB/秒の 40 ~ 80 GB の高帯域幅メモリ (HBM) が搭載されていますが、108 個のストリーム プロセッサにはそれぞれ 192 KB の SRAM が搭載されており、帯域幅は約 19 TB/秒であると推定されます。
サイズは大幅に小さくなりましたが、速度は 10 倍向上していることがわかります。そのため、SRAM をいかに効率的に使用するかが高速化の鍵となります。標準アテンションの実装の背後にある計算を見てみましょう。
標準実装では、HW の動作方法がほとんど考慮されていません。基本的に、HBM ロード/ストア操作をコスト 0 として扱います (「io 対応」ではありません)。
まず、この実装を (時間とメモリの点で) より効率的に行う方法を検討します。最も簡単な方法は、冗長な HBM 読み取り/書き込みを削除することです。
ソフトマックスを計算するために S を (再) ロードするためだけに S を HBM に書き戻すのはどうでしょうか。そうすれば、それを SRAM に保持し、すべての中間ステップを実行して、最終結果を HBM に書き戻すことができます。
カーネルは基本的に「GPU オペレーション」を派手に表現したものです (単なる関数である CUDA の入門に関する以前の投稿を参照してください)。融合を使用すると、複数の操作を融合できます。したがって、HBM から 1 回だけロードし、融合された演算を実行して、結果を書き戻します。そうすることで通信のオーバーヘッドが軽減されます。
ここには「マテリアライゼーション」(物質化・実体化)という専門用語もあります。これは、上記の標準的なアテンション実装では、完全な NxN 行列 (S, P) が割り当てられているという事実を指します。以下では、メモリの複雑さを O(N²) から O(N) に直接削減する方法を見ていきます。
フラッシュ アテンションは基本的に 2 つの主要なポイントに要約されます。
タイリング(前方および後方パス中に使用) - 基本的に、NxN ソフトマックス/スコア マトリックスをチャンクにタイリングします。
再計算(後方パスでのみ使用)
アルゴリズムは次のとおりです。
上記では多くの名詞について言及しましたが、まだ理解していないかもしれません。それは問題ではありません。アルゴリズムを 1 行ずつ説明していきましょう。
フラッシュ アテンション アルゴリズム
タイリング手法の主な障害はソフトマックスです。ソフトマックスはすべてのスコア列を結合する必要があるためです。
分母が見えますか? それが問題です。
入力シーケンス内の特定の i 番目のトークンがシーケンス内の他のトークンに対してどの程度の注目を集めているかを計算するには、これらのスコア (ここでは z_j で示されている) がすべて SRAM ですぐに利用できる必要があります。
しかし、SRAMの容量には限界があります。N (シーケンス長) は 1000 トークンまたは 100000 トークンにすることもできます。したがって、N² は非常に早く爆発します。そこでこの論文では、ソフトマックスの計算を小さなブロックに分割しても、最終的にはまったく同じ結果が得られるというトリックを使用しています。
以前の B スコア (x_1 から x_B) を取得して、それらのソフトマックスを計算するだけです。その後、反復を通じて正しい結果に「収束」します。これらのブロックごとのソフトマックス数値を賢明な方法で組み合わせて、最終結果が実際に正確になるようにします。以下のような方法:
基本的に、最初の 2 つのブロック (サイズ B) に属するスコアのソフトマックスを計算するには、各ブロックの 2 つの統計値、m(x) (最大スコア) と l(x) (合計) を追跡する必要があります。経験値スコアの)。その後、正規化係数を使用してそれらをシームレスにブレンドできます。
ここでは主に基本的な代数演算を説明しますが、f(x) と l(x) の項を拡張して e^x を掛けると、いくつかの項が打ち消し合うため、ここでは書きません。このロジックは最後の (N/B) ブロックまで再帰的に継続され、N 次元的に正しいソフトマックス出力が得られます。
このアルゴリズムの詳細を説明するために、サイズ 1 のバッチ (つまり、単一のシーケンス) と単一のアテンション ヘッドを仮定します。これは、後で拡張されます (GPU 間の単純な並列化によって - 詳細は後ほど)。ドロップアウトとマスキングは後で追加するため、現時点では無視します。
計算を開始します。
初期化: HBM の容量は GB 単位で測定されるため (例: RTX 3090 には 24 GB の VRAM/HBM があり、A100 には 40 ~ 80 GB など)、Q、K、V の割り当ては問題ありません。
ステップ1
行/列のブロック サイズを計算します。なぜ ceil(M / 4 d) なのか? クエリ、キー、および値のベクトルは d 次元であるため、それらを結合して出力の d 次元ベクトルにする必要もあります。したがって、このサイズにより、基本的に qkv および 0 ベクトルを使用して SRAM の容量を最大化できます。
たとえば、M = 1000、d = 5 と仮定します。この場合、ブロック サイズは (1000/4*5)=50 となります。したがって、q、k、v、o ベクトルの 50 ブロックを一度にロードすることで、HBM/SRAM 間の読み取り/書き込みの数を減らすことができます。
B_r についても、最小限の操作を実行するために d を使用する理由がよくわかりません? 知っている人がいたら、コメントしてアドバイスしてください。
ステップ2:
出力行列 O をすべて 0 で初期化します。これはアキュムレータとして機能します。同様に、その目的はソフトマックスの累積分母 (exp スコアの合計) を保持することです。M (行ごとの最大スコアを保持) は、Max 演算子を実行するため -inf に初期化されます。そのため、最初のブロックの Max が何であれ、それは間違いなく -inf より大きくなります。
ステップ 3:
ステップ 1 のブロック サイズは、Q、K、および V をブロックに分割します。
ステップ 4:
O、l、m をブロックに分割します (Q と同じブロック サイズ)。
ステップ5:
列全体、つまりキー/値ベクトル全体のループを開始します (上の図の外側のループ)。
ステップ6:
K_j ブロックと V_j ブロックを HBM から SRAM にロードします。現時点では、SRAM の 50% がまだ空いています (Q および O 専用)。SRAM は次のようになります。
ステップ 7:
行全体、つまりクエリ ベクトル全体で内部ループを開始します。
ステップ8:
Q_i (B_r xd) ブロックと O_i (B_r xd) ブロック、l_i (B_r) と m_i (B_r) を SRAM にロードします。
ここで、l_i と m_i が (すべての中間変数を含む) SRAM にロードできることを確認する必要があります。これは CUDA の知識かもしれませんが、計算方法がわかりません。関連する情報をお持ちの場合は、メッセージを残してください。
ステップ9:
Q_i (B_r xd) と K_j 転置 (dx B_c) の間のドット積を計算して、スコア (B_r x B_c) を取得します。nxns(score) マトリックス全体を「具体化」するわけではありません。
外側のループ インデックスが j (j=3)、内側のループ インデックスが i (i=2)、N が 25、ブロック サイズが 5 であると仮定すると、次の計算結果が得られます (1 ベースのインデックス付けを仮定) :
つまり、入力シーケンス内のトークン 11 ~ 15 のうちのトークン 6 ~ 10 の注意スコアです。ここで重要な点は、これらは正確なスコアであり、決して変更されないということです。
ステップ 10:
前のステップで計算されたスコアを使用して、m_i_j、l*i_j、および P~*i_j を計算します。M ~_i_j は行ごとに計算され、上の各行の最大の要素が検索されます。
次に、要素ごとの演算を適用して P~_i_j を取得します。
正規化 - 行の最大値を取得し、行のスコアからそれを減算し、EXP
l~_i_j は行列 P の行ごとの合計です。
ステップ 11:
m_new_i と l_new_i を計算します。上の図を再利用するのも非常に簡単です。
M_i には、以前のすべてのブロックの行ごとの最大値が含まれます (j=1 および j=2、緑色で示されます)。M_i_j には、現在のブロックの行ごとの最大値 (黄色で表示) が含まれます。m_new_i を取得するには、m_i_j と m_i の間の最大値を取るだけでよく、l_new_i も同様です。
ステップ 12 (最も重要):
これはアルゴリズムの最も難しい部分です。
これにより、行列形式で行単位のスカラー乗算を行うことができます。スカラー s(N) の列と行列 a(NxN) がある場合、diag(s)*a を実行すると、基本的に行 a とこれらのスカラーを要素ごとに乗算することになります。
式 1 (便宜上、ここに再度貼り付けました):
ステップ 12 の最初の項目 (緑色の下線) は、同じ行ブロック内の現在のブロックに先行するブロックの現在のソフトマックス推定を更新します。if j=1 (これはこの行の最初のブロックです。
最初の項には diag(l_i) が乗算され、前の反復で除算された同じ定数が相殺されます (この定数は O_i に隠されています)。
P~_i_j 行列と V ベクトル ブロック (V_j) を直接乗算していることがわかるため、式の 2 番目の項 (黄色の下線) を削除する必要はありません。
e^x 項は、前の反復から m を削除し、これまでの行ごとの最大値を含む最新の推定値 (m_new_i) で更新することにより、行列 P~_i_j & O_i を変更するために使用されます。
これが私の段階的な分析です (実際には 5 分しかかかりません。お役に立てば幸いです!)
重要なのは、これらの外側の e 項と P/O 行列内の e 項が削除されるため、常に最新の m_new_1 推定値が得られるということです。
3 回目の反復も同様で、正しい最終結果が得られました。
思い出してください: これは最終的な O_i の現在の推定値にすぎません。上の画像内のすべての赤いブロックを反復処理した後でのみ、最終的に正確な結果を得ることができます。
ステップ13
蓄積された最新の統計 (l_i および m_i) を HBM に書き込みます。それらの次元数は B_r であることに注意してください。
ステップ13、14、15、1
ネストされた for ループの終わり、O(Nxd) には最終結果、つまり各入力トークンのアテンションの重みのベクトルが含まれます。
簡単なまとめ
このアルゴリズムは、FlashAttendant よりも 2 ~ 4 倍高速で、64k のシーケンス長までスケールアップできるスパース アテンション アルゴリズムである「ブロック スパース FlashAttendant」に簡単に拡張できます。ブロック形式のマスク マトリックスを使用することで、次のことが可能です。上記のネストされた for ループで一部のロード/ストアをスキップし、次の図のようにスパース係数を比例的に保存できるようにします。
ここで、複雑さについて簡単に説明しましょう。
複雑さの分析
スペース: Q、K、V、O (Nxd)、l、m (N) が HBM に割り当てられます。これは 4 N d + 2*N に等しくなります。定数を削除し、d も定数であり、通常は N よりもはるかに小さいことがわかると (例: d={32,64,128}、N={1024,...,100k})、O(N) のスペースが得られ、これが役に立ちます。最大 64k シーケンス長まで拡張可能 (さらに、ALiBi などの他の「トリック」も追加)。
時間: 時間計算量の分析はここでは厳密には行われませんが、適切な指標である HBM アクセス数を使用します。
論文の説明は以下の通り。
この数値はどのようにして取得されたのでしょうか? ネストされた for ループを分析してみましょう。
ブロックサイズはM/4dです。これは、ベクトルが N/(M/4d) ブロックに分割されることを意味します。これを 2 の累乗にすると (行/列のブロックを走査しているため)、O(N²d²/M²) となります。
ブロック全体を一度にフェッチすることはできず、大がかりな分析を行うと、これは標準的な注意よりもそれほど優れたものではないと思われるかもしれませんが、一般的な数値では、これによりアクセス数が 9 分の 1 に削減されます (上記の紙のスクリーンショットによると)。
私たちの擬似アルゴリズムは、バッチ サイズを 1 と仮定して、単一ヘッド アテンションに焦点を当てています。今、私たちは拡大を始めます
多面的な注意
実際、batch_size > 1 および num_heads > 1 にスケールするのはそれほど難しくありません。
アルゴリズムは基本的に単一のスレッド ブロック (CUDA プログラミング用語) によって処理されます。このスレッド ブロックは、単一のストリーミング マルチプロセッサ (SM) 上で実行されます (たとえば、A100 にはそのようなプロセッサが 108 個あります)。計算を並列化するには、batch_size * num_heads スレッド ブロックのみを異なる SM 上で並列実行する必要があります。この数値がシステム上で使用可能な SM の数に近づくほど、使用率は高くなります (各 SM は複数のスレッド ブロックを実行できるため、理想的には複数)。
誤差逆伝播法
GPU メモリの占有に関して、もう 1 つの重要な点はバックプロパゲーションです。出力 O (Nxd) とソフトマックス正規化統計 (N) を保存することで、SRAM の Q、K、V (Nxd) ブロックから直接反転できます。行列 S(NxN) と P(NxN) ! したがって、メモリは O(N) に保たれます。こちらはより専門的で、以下のことが理解できますので、詳しい内容については原論文を参照してください。
コード
最後に、フラッシュ アテンションを使用するときに発生する可能性のある問題のいくつかを見てみましょう。ビデオ メモリの操作が関係するため、CUDA について詳しく説明することしかできませんが、CUDA はより複雑です。
これは、OpenAI の Triton のようなプロジェクトの強みです (FlashAttendant の実装を参照してください)。Triton は基本的に DSL (ドメイン固有言語) であり、CUDA と TVM などの他のドメイン固有言語の間の抽象化レベルです。CUDA を直接処理しなくても、(コンパイルが完了すれば) 超最適化された Python コードを作成することが可能です。このようにして、Python コードを任意のアクセラレータにデプロイできます (これは Triton タスクです)。
もう 1 つの良いニュースは、Triton が最近 PyTorch 2.0 に統合されたことです。
また、シーケンス長が 1K を超える場合など、一部のユースケースでは、一部の近似アテンション メソッド (Linformer など) が高速になり始めています。ただし、フラッシュ アテンションのブロック スパース実装は、他のすべてのメソッドよりも優れています。
要約する
なぜ NVIDIA のエンジニアではなく、スタンフォード大学の学生がこの種の最下位最適化のアルゴリズムをリリースしたのか疑問に思ったことはありますか?
考えられる説明は 2 つあると思います。
1. FlashAttend はより簡単で、最新の GPU でのみ実装可能です (元のコードベースは V100 をサポートしていません)。
2. 通常、「部外者」とは、初心者の目で問題を見て、問題の根本を理解し、基本原理から問題を解決できる人を指します。
最後に、まだまとめが必要です
FlashAttend は、BERT 大規模トレーニングで 15% を節約し、GPT トレーニング速度を 2/3 向上させることができ、コードを変更することなく、これは非常に重要な進歩であり、LLM 研究の方向性に新しい進歩を提案します。
用紙のアドレス:
https://avoid.overfit.cn/post/9d812b7a909e49e6ad4fb115cc25cdc1
著者: アレクサ・ゴーディック