从SD开始搞懂ControlNet

在学习SD后,我个人猜测ControlNet和SD中的Classifier Guidance类似,另外SD有一个遗憾就是没有训练,只玩了采样过程,所以在ControlNet中,再搞清楚原理之后,还会自建数据集进行训练实践(下一篇)!

一、原理介绍

原理部分我们直接拿出Controlnet作者的图

网上对这张图的说明都比较粗略,我们首先要搞明白几个值得含义,参考下面这张图,我们会发现作者在原先的SD之外新增了一个condition的额外输入,这个condition就是我们增加的额外条件,而controlnet的训练是不影响原来的SD的,即原来的SD完全保持独立和不变,影响是通过在SD的Decoder阶段添加condition的编码信息,而训练的部分就说图中的蓝色方块,也就是SD的编码块,如果算力富余的话,controlnet也可以重新训练整个SD。

思路的原理还是比较简单的,以作者的图为例

原先的a输出是,x是输入的二维图像,\theta是neural network block的参数 

添加了Controlnet之后的b输出,Z(;Θz)指的是zero convolution,也就是一个3x3的卷积层,并将权重和偏差都初始化为0 也就是最开始的

Controlnet的损失函数也比较简单,在原先SD的基础之上增加了新的Cf

 二、代码分析

在代码部分,有了SD的基础,我们直接来看Controlnet的forward部分,我们可以看到除了经典的x,timesteps,context三件套输入外,还多了hint

    def forward(self, x, hint, timesteps, context, **kwargs)

这个地方有了前面的基础,我们很容易就能找到调用模型的地方从而知道hint的来源

control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)

 可以看到这里的cond和SD发生了变化,以scribble2img为例,其来源如下,也就是读取输入的图像以后,做一定的图像处理,就可以直接当作c_concat输入进模型了

        img = resize_image(HWC3(input_image), image_resolution)
        H, W, C = img.shape

        detected_map = np.zeros_like(img, dtype=np.uint8)
        detected_map[np.min(img, axis=2) < 127] = 255//这步是根据图像颜色取反,可以调出图片看看,比较抽象,最后的输入是黑底,然后白边是原输入的样子

        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        cond = {"c_concat": [control], "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
        un_cond = {"c_concat": None if guess_mode else [control], "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}

知道了额外条件是如何输入的,我们再返回forward看其是如何被处理的,forward的前两步是对时间步编码,这个在SD中已经熟悉了,接着对将三个条件输入hint、emb、context都输入进了input_hint_block,我们来看一下这是个什么结构。

    def forward(self, x, hint, timesteps, context, **kwargs):
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

这里回顾一下TimestepEmbedSequential,其主要的作用其实就是分辨输入是否要和emb以及context结合,对于input_hint_block其主要作用是对于hint进行编码,所以输入其实只有hint一个,然后就是一系列的3*3卷积进行特征处理,值得注意的是最后一个卷积层加了一个zero_module,作用就是将该层的参数置0,这样做的目的也很明显,一方面是调整hint的通道数,使其可以被输入进Unet,另一方面也是做特征提取,和SD的输入一样,被置入隐式表达,可以减少显存占用,加快推理速度

        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb, context=None):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                x = layer(x, emb)
            elif isinstance(layer, SpatialTransformer):
                x = layer(x, context)
            else:
                x = layer(x)
        return x

后面的步骤和SD大同小异,首先module也就是SD中的input_blocks其由resnetblock和spatialtransformer组成,最后输出的就是噪音了,但是在这里噪音还要加速之前的guided_hint,我们来看一下这里新出现的zero_conv,其实就是一个置0的卷积层,而且是1*1的conv2d,这个过程虽然输入了emb和context,但是其都是不作用的,只对h做作用,那么此时的h就是guided_hint+原来SDUnet输出的噪音

        outs = []

        h = x.type(self.dtype)
        for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
                h += guided_hint
                guided_hint = None
            else:
                h = module(h, emb, context)
            outs.append(zero_conv(h, emb, context))


        self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])

    def make_zero_conv(self, channels):
        return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))

 在降采样并提取特征之后,是middle_block,其结构还是比较清晰的,最后就直接返回outs了。可见其推理过程是相当简单的,相比于SD核心就是增加了新的输入hint和其编码部分input_hint_block

        h = self.middle_block(h, emb, context)
        outs.append(self.middle_block_out(h, emb, context))
        return outs

        self.middle_block_out = self.make_zero_conv(ch)

        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attn
                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
                disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                use_checkpoint=use_checkpoint
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

猜你喜欢

转载自blog.csdn.net/fisherisfish/article/details/132576677
sd