AIGC によるコードの観点からの DDPM (拡散モデル) の理解

最近 AIGC を学ぶ予定なので、DDPM などの基本的なネットワークから始める必要があります。この記事はコード分析の観点< a i=2> 誰もが学び、理解できるように。 DDPM (Denoising Diffusion Probabilistic Models) は拡散モデルです。

拡散モデルには、ノイズ処理ノイズ除去処理<という 2 つの主要なプロセスが含まれています。 a i=3> a>。上図に対応すると、x0からxtまでがノイズを付加する処理、xtからx0がノイズ除去の処理となります。

順方向ノイズ追加プロセスと逆方向ノイズ除去プロセスは両方ともマルコフ連鎖であり、プロセス全体には約 < a i =3>1000 ステップ。

順方向ノイズ付加処理は、入力データにノイズ (ガウス ノイズ) を連続的に付加する処理です。

逆ノイズ除去プロセスは、標準ガウス分布からノイズ サンプルを 1 つずつ徐々に取得し、最終的に生成されたサンプル データを取得します。

そのノイズ追加プロセスの式は次のとおりです: 

x_{t}=\sqrt{\alpha_{t}}x_{t-1}+\sqrt{1-\alpha _{t}}z_{1}

ここ\sqrt{\alpha_{t}}ノイズ スケジュールと呼ばれる事前に設定されたハイパーパラメータで、通常は1 の値の範囲は 0.9999 ~ 0.998 です。 [上記の式は、x_{t}x_{t-1} からどのように導出されるかを示しています。

では、x_{t}x_{t-2} の間にはどのような関係があるのでしょうか?これを前方に推定することができます (つまり、 x_{t-1} を展開します)。

x_{t}=\sqrt{\alpha _{t}}(\sqrt{\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t-1}}z_{2 })+\sqrt{1-\alpha_{t}}z_{1}

毎回追加されるノイズは正規分布に従うz_{1},z_{2}...\sim N(0,1)ので、上記の式を整理すると次のことが得られます。

x_{t}=\sqrt{\alpha_{t}\alpha_{t-1}}x_{t-2}+\sqrt{1-\alpha_{t}\alpha_{t-1}}z_{2}

次に、特定のパターンが見つかり、x_{t}x_{0} の関係を取得できます。

x_{t}=\sqrt{\overline{\alpha_{t}}}x_{x0}+\sqrt{1-\overline{\alpha_{t}}}z_{t} 


DDPM はコード内で次のように定義されます。

コードは Bubbliing のコードを使用します。

net    = GaussianDiffusion(UNet(3, self.channel), self.input_shape, 3, betas=betas)

拡散モデルの入力パラメータはUNet ネットワーク、input_shape は入力サイズ (つまり、記事の冒頭)、その値は )。ベータの定義は次のとおりです (もちろん、コサインを使用してベータを生成できます。ここでは線形の例を使用します)。ここで設定されているのは 1000 以内の均一分布です ( の間、 ノイズ テーブルを生成します を指します。ベータは線形タイムテーブルであり、次の目的で使用できます。 画像入力チャネル 、3 は \アルファ_{t}schedule_lowschedule_high num_timesteps

betas = generate_linear_schedule(
                self.num_timesteps,
                self.schedule_low * 1000 / self.num_timesteps,
                self.schedule_high * 1000 / self.num_timesteps,
            )

トレーニングフォワード機能部

次に、GaussianDiffusion のコード内に入り、各コンポーネントを調べます。内部のforward関数に直接アクセスして、画像がどのように処理されるかを見てみましょう。

    def forward(self, x, y=None):
        b, c, h, w  = x.shape
        device      = x.device

        if h != self.img_size[0]:
            raise ValueError("image height does not match diffusion parameters")
        if w != self.img_size[0]:
            raise ValueError("image width does not match diffusion parameters")
        # 随机生成batch个范围在0~1000内的数
        t = torch.randint(0, self.num_timesteps, (b,), device=device)
        return self.get_losses(x, t, y)

 GaussianDiffusion の前方部分では、x が入力画像で、その後に は、ランダム生成範囲が 0~num_timesteps [タイム ステップ] バッチサイズ番号であることを示します。または、それぞれに ランダム性を与えると理解できます。バッチ (写真) タイムスタンプを付けます。次に、コードを段階的に調べて、get_losses 関数を入力します。


get_losses セクション

次は get_losses コードです。これには 3 つの入力 x、t、y があります。ここでx はトレーニング用に入力した画像ですt は上記でランダムに生成されたタイムスタンプです< a i=4>。

    def get_losses(self, x, t, y):
        # x, noise [batch_size, 3, 64, 64]
        noise           = torch.randn_like(x)  # 产生与输入图片shape一样的随机噪声(正态分布)

        perturbed_x     = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)
        return loss

関数内では、入力画像のサイズと同じ正規分布を持つランダム ノイズが最初に作成されノイズ、次に< /span> a> 障害対応。 時刻 t で入力画像にノイズを追加することですこの関数は、perturb_x

perturb_x関数部分

次に、perturb_x が時刻 t の画像にどのようにノイズを追加するかを見てみましょう (頭をすっきりさせてください。これらのコードはマトリョーシカ人形のように 1 つずつ階層化されています)。

この関数には、3 つの入力パラメータがあります。画像を入力し、この時点で t=323 とすると、次の場合に画像にノイズ t を追加すると理解できます。タイムスタンプは 323 で、このピクチャは時刻 t の入力 Xt に対応します。 sqrt_alphas_cumprodsqrt_one_minus_alphas_cumprod これら 2 つのテンソルを使用して入力画像xとノイズノイズの混合比率を時間次元で制御します。

    def perturb_x(self, x, t, noise):
        '''
        :param x:输入图像
        :param t: 每个图片不同的时间戳(范围在0~1000)
        :param noise: 与输入图片shape一样的正态分布随机噪声
        :return:经过扰动后的图像
        '''
        return (
            extract(self.sqrt_alphas_cumprod, t,  x.shape) * x +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * noise
        )

 perturb_x のプロセスを視覚化できます。たとえば、次のノイズのない元の画像があります。

 Perturb x から perturb_x までノイズを追加した後の効果:

画像に対するノイズの拡散妨害効果も制御できます。

 以上が時刻tに対応する画像Xtにノイズを付加する処理である。時間 t が経過するにつれてますますぼやけていきます

その後、get_losses 関数に戻ります (コードは次のとおりです)。perturbed_x はノイズを追加した後の時間 t における画像 Xt であり、ここでのモデルはバックボーン ネットワークです。UNet ネットワーク (UNet ネットワーク部分は別途取り出します)。次に、get_losses の主なプロセスを要約します。

ステップ 1.入力画像を perturb_x を通じて時間領域で摂動させ、それをと比較します。ランダム ノイズ noise が混合され、 生成される perturbed_x 乱れた画像 ;

ステップ 2.UNet ネットワークを通じてノイズのある画像を予測し、予測ノイズ信号estimate_noise を取得します。

ステップ 3.予測ノイズestimated_noiseと実際のノイズnoiseの損失を計算します。

    def get_losses(self, x, t, y):
        # x, noise [batch_size, 3, 64, 64]
        noise           = torch.randn_like(x)  # 产生与输入图片shape一样的随机噪声(正态分布)

        perturbed_x     = self.perturb_x(x, t, noise)
        estimated_noise = self.model(perturbed_x, t, y)

        if self.loss_type == "l1":
            loss = F.l1_loss(estimated_noise, noise)
        elif self.loss_type == "l2":
            loss = F.mse_loss(estimated_noise, noise)
        return loss

つまり、予測ノイズと実際のノイズの間の損失関係は、トレーニング フェーズ中に計算されます 

予測段階

    @torch.no_grad()
    def sample(self, batch_size, device, y=None, use_ema=True):
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        
        for t in tqdm(range(self.num_timesteps - 1, -1, -1), desc='remove noise....'):
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
        
        return x.cpu().detach()

予測段階では入力ノイズ (ここでは入力画像ではありません) を実行します ノイズ< /span> が処理されて、最終的に生成されたイメージが取得されます。

正規分布ノイズxを入力し、連続的にノイズ除去を行います(Xt~X0の処理)。


ネットワークモデルの構造

DDP は Unet で構成されていますので、まず Unet の構成を見てみましょう。

class UNet(nn.Module):
    def __init__(
        self, img_channels, base_channels=128, channel_mults=(1, 2, 4, 8),
        num_res_blocks=3, time_emb_dim=128 * 4, time_emb_scale=1.0, num_classes=None, activation=SiLU(),
        dropout=0.1, attention_resolutions=(1,), norm="gn", num_groups=32, initial_pad=0,
    ):

time_mlp

self.time_mlp = nn.Sequential(
            PositionalEmbedding(base_channels, time_emb_scale),
            nn.Linear(base_channels, time_emb_dim),
            SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim),
        ) if time_emb_dim is not None else None

time_mlp は、PositionalEmbedding 層、Linear、SiLu、Linear で構成されます。

位置埋め込みレイヤー

class PositionalEmbedding(nn.Module):
    def __init__(self, dim, scale=1.0):
        super().__init__()
        assert dim % 2 == 0
        self.dim = dim
        self.scale = scale

    def forward(self, x):
        device      = x.device
        half_dim    = self.dim // 2
        emb = math.log(10000) / half_dim
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        # x * self.scale和emb外积
        emb = torch.outer(x * self.scale, emb)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb

コードの前方の x は時間 (時間軸) は画像ではありません a>。

この関数は主に位置エンコードに使用されます。位置エンコードではサインとコサインを使用して位置を計算できます。使用される式は次のとおりです。

PE_{pos,2i}=sin(pos/10000^{2i/d_{モデル}})

PE_{pos,2i+1}=cos(pos/10000^{2i/d_{モデル}})

位置エンコード式では、pos はシーケンス内の各位置のインデックスを表します。長さ 4 x のシーケンスの場合、各位置には 0 から 3 のインデックスが付けられます。各位置の位置エンコーディング ベクトルを計算する際には、このインデックス値を使用して計算します。

具体的には、式内の pos はシーケンス内の位置インデックスを表し、位置エンコード ベクトルの計算中にサインとコサインの関数パラメーターを計算するために使用されます。

たとえば、位置エンコード行列の最初の位置エンコード ベクトルを計算する場合、pos の値は 0 になります。2 番目の位置エンコード ベクトルを計算する場合、< a i=2> の値は 1 などです。 pos

たとえば、次のようなシーケンスがあります。

# 设置向量的长度和位置编码的维度
vector_length = 4
embedding_dim = 4

# 生成位置编码矩阵
pos_encoding = np.zeros((vector_length, embedding_dim))

for pos in range(vector_length):
    for i in range(embedding_dim):
        if i % 2 == 0:
            pos_encoding[pos, i] = np.sin(pos / (10000 ** (2 * i / embedding_dim)))
        else:
            pos_encoding[pos, i] = np.cos(pos / (10000 ** (2 * (i - 1) / embedding_dim)))

# 打印位置编码矩阵
print(pos_encoding)

取得した位置エンコーディング行列は以下のとおりです 

[[ 0.00000000e+00 1.00000000e+00 0.00000000e+00 1.00000000e+00]
 [ 8.41470985e-01 5.40302306e-01 9.99999 998e-05 9.99999995e-01 ]
 [ 9.09297427e-01 -4.16146837e-01 1.99999999e-04 9.99999980e-01]
 [ 1.41120008e-01 -9.89992497e-01 2.9 9999995e -04 9.99999955e-01]]

このうち、配列の各行は位置エンコード行列の位置に対応し、最初の列はを表します。 この位置のサイン関数の値。2 番目の列は、この位置のコサイン関数の値を示します。の上。 

つまり、この関数では、sin と cos位置情報を取得します。

残留ブロック

class ResidualBlock(nn.Module):
    def __init__(
        self, in_channels, out_channels, dropout, time_emb_dim=None, num_classes=None, activation=SiLU(),
        norm="gn", num_groups=32, use_attention=False,
    ):
        super().__init__()

        self.activation = activation

        self.norm_1 = get_norm(norm, in_channels, num_groups)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)

        self.norm_2 = get_norm(norm, out_channels, num_groups)
        self.conv_2 = nn.Sequential(
            nn.Dropout(p=dropout), 
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
        )

        self.time_bias  = nn.Linear(time_emb_dim, out_channels) if time_emb_dim is not None else None
        self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

        self.residual_connection    = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
        self.attention              = nn.Identity() if not use_attention else AttentionBlock(out_channels, norm, num_groups)
    
    def forward(self, x, time_emb=None, y=None):
        out = self.activation(self.norm_1(x))
        # 第一个卷积
        out = self.conv_1(out)
        
        # 对时间time_emb做一个全连接,施加在通道上
        if self.time_bias is not None:
            if time_emb is None:
                raise ValueError("time conditioning was specified but time_emb is not passed")
            out += self.time_bias(self.activation(time_emb))[:, :, None, None]

        # 对种类y_emb做一个全连接,施加在通道上
        if self.class_bias is not None:
            if y is None:
                raise ValueError("class conditioning was specified but y is not passed")

            out += self.class_bias(y)[:, :, None, None]

        out = self.activation(self.norm_2(out))
        # 第二个卷积+残差边
        out = self.conv_2(out) + self.residual_connection(x)
        # 最后做个Attention
        out = self.attention(out)
        return out

。 。 。 。まだ更新されていません

おすすめ

転載: blog.csdn.net/z240626191s/article/details/133933052