安定拡散原理の紹介とソースコード解析 (1)

安定拡散原理の紹介とソースコード解析 (1)

序文(本文とは関係ないので省略可)

Stable Diffusion は、Stability AI のオープンソース AI Wensen グラフ拡散モデルです。以前、 Diffusion、Diffusion Model の原理といくつかのアルゴリズムコードを紹介しました.基本的な好奇心を満たした後、棚に置きました.最近のAIGC ははるかに高速です. 私の期待は、特に次の AI によって生成された画像を使い果たした後、Stable Diffusion が最終的に私の視野に戻ってきました:

アルゴリズムエンジニアとして、物事の本質を見抜く目を持つ必要があります.私がこの写真に最初に惹かれたのは、コンテンツではなく、その生成の品質でした.画像は高精細であり、以前見たラフなおもちゃとは比べものにならないほどのディテールに富んでいます. 赤 箱に記された不適合にも欠陥があります. したがって、Stable Diffusion のエンジニアリング フレームワーク全体の原理をさらに詳しく分析することが急務であり、今後、赤枠内の不一致を修復し、AIGC のさらなる発展に貢献できることを期待しています。

概要

Stable Diffusion のフレームワーク全体のソースコードは数万行に及ぶため、すべてを解析する必要はありません。この記事では、「テキストから画像へ」をメイン ラインとして、Stable Diffusion の操作プロセスと各重要なコンポーネント モジュールを調べ、導入部で「総得点」の形式を採用し、最初に全体的なフレームワークを要約し、次にそれぞれを分析します。コンポーネント (DDPM、DDIM など)、およびコード内の一部の非主流ロジック ( など) については、predict_cidsこれらreturn_idsの細部に関する私の見解について説明します。記事の内容は長く、複数の部分に分割する準備ができています。

送信元アドレス:安定拡散

例証する

これまでコード分析の記事をたくさん書いてきましたが、問題に遭遇して読み直すと、目的の場所をすばやく特定し、コードの意図を正確に理解することは依然として非常に困難であることがわかります。抜粋で紹介される実装の詳細が多すぎて、情報伝達の効率が低下するためです。

少し考えた後、トラブルを回避しようとするのをやめ、疑似コードを使用してコア原則を記録することにしました。私は通常、コードを深く分析するときにこの方法を使用します.コードを抽象化するのに時間がかかりますが、有益だと思います. たとえば、DDPM モデルの順方向拡散コードを疑似コードで記述すると、次のような効果があります。

無関係な実装の詳細を削除した後、DDPM の実装は非常に単純であることがわかります.特定の注釈と組み合わせると、簡単かつ迅速に理解でき、人々は全体的かつ包括的な制御の感覚を得ることができます. さらに、コードの実装の詳細をより直感的に表示するために、より多くのブロック図、モデル図などをテキストに追加する必要があります。

WeChat で「Jenny's Algorithm Road」または「world4458」を検索し、WeChat パブリック アカウントをフォローして、元の技術記事の最新の更新を適時に入手できます。

また、Zhihu コラムPoorMemory-Machine Learningも参照できます。今後の記事も Zhihu コラムに掲載される予定です。

安定拡散の全体像

まず、画像を生成するための Stable Diffusion テキストの全体的なフレームワークを見てください (記事の描画は血を吐きます... いつの日か AI が支援できることを願っています)。

上の写真のフレームワークにはたくさんのモジュールがあり、上から下に3つのブロックに分かれています.私はそれらを写真の中でPart 1, 2, 3とマークしました. このフレームワークには、トレーニングとサンプリングの 2 つの段階が含まれています。

  • トレーニング フェーズ (図のパート 1 とパート 2 を参照) には、主に次のものが含まれます。

    1. AutoEncoderKL セルフ エンコーダーを使用して、画像 Image をピクセル空間から潜在空間にマッピングし、画像の暗黙的な表現を学習します. AutoEncoderKL エンコーダーは事前にトレーニングされており、パラメーターは固定されていることに注意してください. このとき、画像のサイズは から に[B, C, H, W]変換されます[B, Z, H/8, W/8]。ここで、潜在空間内の画像のチャネル数Zを表します。このプロセスは Stable Diffusion コードで呼び出されますencode_first_stage
    2. FrozenCLIPEmbedder テキスト エンコーダーを使用して Prompt プロンプト ワードをエンコードし、サイズ (つまり)[B, K, E]埋め込み表現を生成します。ここで、テキストの最大長の最大エンコード長を表し、埋め込みのサイズを表します。このプロセスは Stable Diffusion コードで呼び出されますcontextKEget_learned_conditioning
    3. 画像の暗黙的表現に連続的にノイズを加える順拡散処理 (Diffusion Process) を行う. この処理は UNetModel を呼び出して完了する. UNetModel は画像の潜像とテキスト埋め込みを同時に受け取り, Attention を条件として使用する.トレーニングcontext.contextテキストと画像の一致関係をよりよく学習するためのメカニズム。
    4. 拡散モデルの出力ノイズϵ θ \epsilon_{\theta}ϵ, 計算と実際のノイズの間の誤差が損失として使用され、UNetModel モデルのパラメーターが逆伝播アルゴリズムによって更新されます. AutoEncoderKL と FrozenCLIPEmbedder のパラメーターは、このプロセス中に更新されないことに注意してください.
  • サンプリング フェーズ (図のパート 2 とパート 3 を参照)、つまり、モデル パラメーターを読み込んだ後、プロンプト ワードを入力して画像を出力します。主に次のものが含まれます。

    1. FrozenCLIPEmbedder テキスト エンコーダーを使用して Prompt プロンプト ワードをエンコードし、サイズ[B, K, E]の(つまりcontext) を生成します。
    2. のサイズ[B, Z, H/8, W/8]の、トレーニング済みの UNetModel モデルを使用し、DDPM/DDIM/PLMS およびその他のアルゴリズムに従って T 回反復し、ノイズを継続的に除去し、画像の潜在表現を復元します。
    3. AutoEncoderKL を使用して[B, Z, H/8, W/8]画像の潜在表現 (サイズは ) をデコード (デコード) し、最終的にピクセル空間の画像を復元します。画像サイズは です。このプロセスは[B, C, H, W]Stable Diffusion で呼び出されますdecode_first_stage

上記の紹介の後、Stable Diffusion の全体像をより明確に理解できるようになります. 次に、図に従って、各主要モジュールを理解するために最善を尽くします. FrozenCLIPEmbedder と DPM アルゴリズム (図には書かれていません) を除いて、個人のエネルギーと限られた自由時間に制限されています。

  • UNetModel
  • AutoEncoderKL & VQModelInterface (変分オートエンコーダーでもありますが、図にはありません)
  • DDPM、DDIM、PLMS アルゴリズム

後で簡単に紹介し、学習プロセスを記録します。

重要書類

コードを読む過程で、分厚い論文を読まなければならないことがわかりました。拡散モデルの理論的導出はまだやや複雑ですが、式の導出とコードの実装を組み合わせることで、知識の理解を深めることができます。コードを読むのに大いに役立った論文のリストを次に示します。

重要成分の分析

以下は、Stable Diffusion の重要なコンポーネントの簡単な分析です。主に次のものが含まれます。

  • UNetModel
  • DDPM、DDIM、PLMS アルゴリズム
  • オートエンコーダーKL
  • predict_cidsなど、主流ではないロジックについてreturn_ids話します。

最初に UNetModel 構造を導入して、後続の記事を直接引用できるようにします。

UNetModel の紹介

Stable Diffusion で使用する UNetModel を描画した後は、コードを解析する必要はありません。Stable Diffusion は、UNetModel の Encoder-Decoder 構造を使用して、拡散プロセスを実現し、ノイズを予測します. ネットワーク構造は次のとおりです。

モデルへの入力は、次の 3 つの部分で構成されます。

  • のサイズ[B, C, H, W]画像イメージ。サイズを表すために使用される記号を気にしないように注意してください。それらはインターフェイスと見なされるべきです。たとえば、UNetModel が[B, Z, H/8, W/8]のサイズのCますZHに等しいH/8Wに等しいW/8;
  • サイズ[B,]のタイムステップ
  • サイズ[B, K, E]がテキスト埋め込み表現context。ここで、 は最大エンコード長Kを表し、E埋め込みサイズを表します。

モデルはDownSampleと をUpSampleサンプルのダウンサンプリングとアップサンプリングを行い、最も頻繁に現れるモジュールはResBlockSpatialTransformer、図のそれぞれが前のモジュールからの入力とタイムステップに対応する埋め込みをResBlock受け取ります(timestep_embサイズは構成可能なパラメーターです)。図では、前のモジュールからの入力と(Prompt テキストの埋め込み表現) を受け取り、Cross Attention を使用して、Prompt と画像の一致関係を条件として。ただし、この図では、2 つのモジュールが点線のボックス内に複数の入力を持っていることのみを示しており、他のモジュールは描かれていません)。[B, 4*M]MSpatialTransformercontextcontext

[B, C, H, W]最終モデルの出力サイズは入力サイズと同じであることがわかります。これは、UNetModel が入力と出力のサイズを変更しないことを意味します。

ResBlocktimestep_embeddingおよびcontextの実装をそれぞれ見てみましょう。SpatialTransformer

ResBlockの実装

ResBlock ネットワーク構造図は次のとおりです。画像xと。

timestep_embedding の実装

timestep_embedding の生成方法は次のとおりです。論文 Tranformer (Attention is All you Need) の方法を使用します。

プロンプトテキスト埋め込みの実装

つまりcontext、の実現です。Prompt はコーディングに CLIP モデルを使用しています.CLIP モデルについては詳しく調べていないので,当面深く読むつもりはありません.後で機会があれば追加します.コードは生成されます.事前トレーニング済みの CLIP を使用context:

SpatialTransformer の実装

最後にモジュールの多い の実装SpatialTransformerを、画像を入力する際に​​はcontext条件情報としてテキストも利用しており、両者はモデリングに Cross Attention を利用している。さらに展開してみると、実際にCross Attentionモジュールを呼び出していることがわかり、Cross Attentionモジュールでは画像情報をQuery、テキスト情報をKey&Valueとして、モデルは両者の相関関係に注目しますSpatialTransformerBasicTransformerBlock画像とテキストの各部分の内容:

ここでの Cross Attention の役割を理解するために、簡単なアイデアを使用できると思います。たとえば、トレーニング中に馬が草を食べている写真や、「砂漠で放牧されている白い馬」というテキスト プロンプトを与えて、Attention When を実行するなどです。テキスト内のキーワード「馬」は、画像内の動物 (「馬」も含む) との関連性が高くなります。これは、重量も大きいためです。一方、「馬」、「白」、「砂漠」、「草」などの重量は同じです。この時点で、モデルが十分にトレーニングされている場合、モデルは画像とテキストの一致関係を学習できるだけでなく、テキスト内のどのキーワードが画像内で強調表示されたいかを Attention main を通じて学習できます。体。

そして、「馬が草を食んでいる」と入力するなど、即発語を入力してモデルを使って画像を生成すると、この時点でモデルは画像とテキストの相関関係とテキスト内の重要な情報を捉えることができたので、 、テキスト「馬」を見ると、ブラックボックスマジックの操作の下で、画像「馬」の世代が強調表示されます;「草」を見ると、画像「草」の世代が強調表示されるので、できるだけテキストに一致する画像を生成するように。

ここまでで、UNetModel の重要な各コンポーネントの基本的な紹介が完了しました。

まとめ

UNetModel モデルの構造は複雑ではないので、基本的には絵を見ながらコードを書くことができます。さらに、各モジュールの出力結果のサイズをマークしました。これは、脳内でモデルを実行するのに非常に便利です。

この記事では、Stable Diffusion Vincent ダイアグラム コードの全体的なフレームワークを簡単に紹介し、拡散モデルに関するいくつかのコア ペーパーをリストし、UNetModel を簡単に分析します。フォローアップして、他のコア コンポーネントを分析します。

AIGCの開発が速すぎて、それを学ぶことができないことに気づきました...Zhuangziが言ったことはますます真実だと感じています:限界があれば、限界はなく、終わります!

おすすめ

転載: blog.csdn.net/Eric_1993/article/details/129393890