IDDPM コード ResBlock と TimestepEmbedSequential の解釈

IDDPM コード ResBlock と TimestepEmbedSequential の解釈

ResBlock の forward と _forward の違い

# ResBlock是为了把embedding以残差的形式和图片加起来,即把时间信息融合到图片中去
class ResBlock(TimestepBlock): 
    # resblock是继承自timestepblock的,所以所有的resblock部分肯定是要传入embedding的
    # 而在attention, 上采样,下采样都不需要传入embedding
    """
    A residual block that can optionally change the number of channels.

    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        self.in_layers = nn.Sequential(
            normalization(channels),
            SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )
        self.emb_layers = nn.Sequential(
            SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        if self.out_channels == channels:  # 如果通道数目一致的话,直接连起来就好
            self.skip_connection = nn.Identity()
        elif use_conv:# 如果通道数目不一致的话,可以用一个大小不变的卷积去做
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:# 在有的论文中,如果通道数目不一致的话,也可以用一个1*1的卷积去做逐点的卷积
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.

        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb): # _forward 是私有方法,它执行实际的计算并将其结果返回给 forward
        h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift # ys = (1+scale), yb = bias
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h    # identity

このクラスでは、 forward は、このモジュールのサブモジュールに入力を渡し、結果を返すためのパブリック インターフェイスです。_forward は、実際の計算を実行し、その結果を forward に返すプライベート メソッドです。このクラスでは、forward メソッドがチェックポイントを呼び出して、PyTorch の自動微分メカニズムを利用してメモリ使用量を削減していることがわかります。次に、_forward メソッドはすべての計算を実行し、最終結果を forward に返します。したがって、forward メソッドは ResBlock クラスの外部インターフェイスであり、_forward メソッドはその内部実装であると言えます。

タイムステップ埋め込みシーケンシャル

class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """
    # emb:timestep embedding和condition embedding混合起来的
    def forward(self, x, emb):
        for layer in self:  
            if isinstance(layer, TimestepBlock):# 只有layer是timestepblock的时候才输入emb
                x = layer(x, emb)
            else:
                x = layer(x)
        return x

その呼び出しステートメントは

TimestepEmbedSequential(
                   conv_nd(dims, in_channels, model_channels, 3, padding=1)
                ) conv_nd的结果是怎么传入TimestepEmbedSequential的?

TimestepEmbedSequential を呼び出すときは、 conv_nd(dims, in_channels, model_channels, 3, padding=1) によって作成された nn.Module オブジェクトを渡します。conv_nd は、指定されたパラメーターから畳み込み層オブジェクトを作成するファクトリ関数です。ここで、conv_nd は dims の次元 (dims は、1 次元畳み込みの場合は 1、2 次元畳み込みの場合は 2、3 次元畳み込みの場合は 3 など、畳み込み層の次元を表す整数です)、入力チャネルの数を返します。は in_channels で、出力チャネル番号は model_channels で、コンボリューション カーネル サイズ 3 のコンボリューション層オブジェクトがパラメーターとして TimestepEmbedSequential のコンストラクターに渡されます。

TimestepEmbedSequential の forward メソッドでは、畳み込み層オブジェクトが nn.Sequential のサブモジュールとして使用されます。つまり、セルフ リストに追加されます。forward メソッドを呼び出すと、入力テンソル x とタイムステップ埋め込みテンソル emb が nn.Sequential の各サブモジュールに順番に渡されます。サブモジュールが TimestepBlock タイプの場合、emb は追加の入力としてサブモジュールの forward メソッドに渡されます。この例では、conv_nd は TimestepBlock タイプではないため、emb は無視され、x のみがレイヤーの forward メソッドに渡されます。

おすすめ

転載: blog.csdn.net/weixin_43845922/article/details/129937987