Diffusion model study notes 2: Stable Diffusion structure analysis, using text-to-image generation as an example

Preface

I have used Stable Diffusion for a long time but never properly analyzed its internal structure, so I am writing this blog post to record it, hehe.

Source code download

https://github.com/bubbliiiing/stable-diffusion

Feel free to give it a star if you like it.

Network construction

1. What is Stable Diffusion (SD)

Stable Diffusion is a relatively new diffusion model. Despite the "stable" in its name, changing the seed produces completely different results, so the output is actually quite unstable.

Stable Diffusion was first used to generate images from text (text-to-image). As the technology developed, it came to support image-to-image generation as well as control methods such as ControlNet for customizing the generated images.

Stable Diffusion is based on the diffusion model, so it necessarily involves a process of repeated denoising. When starting from an existing picture (image-to-image), there is also a process of repeatedly adding noise, which is where the classic DDPM picture comes in.

Compared with DDPM, Stable Diffusion uses the DDIM sampler, performs diffusion in the latent space, and is pre-trained on the very large LAION-5B dataset.

Fully fine-tuning Stable Diffusion costs more than most readers can afford, but there are many lightweight fine-tuning schemes such as LoRA and Textual Inversion. Those are a topic for a later post.

This article analyzes the structure of the whole SD model, a single sampling step, and the full multi-step sampling process.

Large models and AIGC are the current industry trend; if you don't understand them, you can easily be left behind, hh.

2. Composition of Stable Diffusion

Stable Diffusion consists of four parts.
1. The sampler.
2. The Variational Autoencoder (VAE).
3. The UNet main network, which predicts the noise.
4. The CLIPEmbedder text encoder.

Every part matters. We take text-to-image generation as the example: the only input is text, and there is no input image.

3. Generation process

The generation process is divided into three parts:
1. Encode the prompt text.
2. Sample several times.
3. Decode.

import random

import einops
import numpy as np
import torch
# seed_everything is assumed to come from pytorch_lightning
from pytorch_lightning import seed_everything

# model, ddim_sampler, and the prompt/sampling settings are created earlier in the script.
with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   Encode the prompt
    # ----------------------- #
    cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    # ----------------------- #
    #   Run the sampling loop
    # ----------------------- #
    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                    shape, cond, verbose=False, eta=eta,
                                                    unconditional_guidance_scale=scale,
                                                    unconditional_conditioning=un_cond)

    # ----------------------- #
    #   Decode the latent
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

1. Text encoding

Text encoding is relatively simple: the CLIP text encoder is used directly. The code defines a FrozenCLIPEmbedder class that uses CLIPTokenizer and CLIPTextModel from the transformers library.

In the forward pass, we first tokenize the input text with CLIPTokenizer and then extract features with CLIPTextModel. FrozenCLIPEmbedder returns a feature tensor of shape [batch_size, 77, 768].

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        # Define the text tokenizer and transformer
        self.tokenizer      = CLIPTokenizer.from_pretrained(version)
        self.transformer    = CLIPTextModel.from_pretrained(version)
        self.device         = device
        self.max_length     = max_length
        # Freeze the model parameters
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # Tokenize and encode the input text, padding straight to a length of 77.
        batch_encoding  = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # Take input_ids and pass them to the transformer for feature extraction.
        tokens          = batch_encoding["input_ids"].to(self.device)
        outputs         = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        # Take the token features from the requested layer
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)
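
To make the fixed length of 77 concrete, here is a minimal sketch of what padding="max_length" together with truncation=True amounts to. It is not the tokenizer's actual implementation: the helper pad_or_truncate is hypothetical, and the special-token ids (49406 for start-of-text, 49407 for end-of-text, which I believe also serves as the padding token in the openai/clip-vit-large-patch14 vocabulary) are assumptions.

```python
BOS_ID = 49406  # assumed start-of-text id in the CLIP vocabulary
EOS_ID = 49407  # assumed end-of-text id, also used here for padding


def pad_or_truncate(token_ids, max_length=77):
    """Wrap raw token ids with BOS/EOS and force the result to max_length."""
    # Keep at most max_length - 2 prompt tokens so BOS and EOS still fit.
    body = list(token_ids)[: max_length - 2]
    ids = [BOS_ID] + body + [EOS_ID]
    # Pad on the right until the sequence has exactly max_length entries.
    ids += [EOS_ID] * (max_length - len(ids))
    return ids


short_prompt = pad_or_truncate([320, 1125, 539])   # three made-up token ids
long_prompt = pad_or_truncate(range(1000, 1200))   # 200 made-up token ids
print(len(short_prompt), len(long_prompt))
```

Whatever the prompt length, the encoder always sees 77 positions, which is why the output shape is [batch_size, 77, 768] regardless of the text.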

2. Sampling process


a. Generate initial noise

Since there is only text in the input and no input image, where does the initial noise come from?

Directly sampling noise from a normal distribution is enough. A simple way to see it: during training, normally distributed noise is added to the original image over and over until it ends up as a noise matrix, so it is reasonable to initialize a normally distributed noise matrix and generate the picture from that.
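
The claim that a fully noised image is close to pure Gaussian noise can be checked with a small sketch of the closed-form forward process. The schedule below (1000 steps, "scaled linear" betas from 0.00085 to 0.012) is an assumption based on the values commonly used for Stable Diffusion v1, and q_sample here is a standalone toy function, not the repository's method.

```python
import torch

# Assumed schedule: 1000 steps, "scaled linear" betas from 0.00085 to 0.012.
num_timesteps = 1000
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, num_timesteps, dtype=torch.float64) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)


def q_sample(x0, t, noise):
    """Closed-form forward noising: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
    a_bar = alphas_cumprod[t]
    return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise


torch.manual_seed(0)
x0 = torch.ones(4, 64, 64, dtype=torch.float64)    # stand-in for a clean latent
noise = torch.randn_like(x0)
x_early = q_sample(x0, 0, noise)                   # still very close to x0
x_last = q_sample(x0, num_timesteps - 1, noise)    # very close to the noise itself
signal_left = alphas_cumprod[-1].sqrt().item()     # weight of x0 at the last step
print(f"signal coefficient at the last step: {signal_left:.3f}")
```

Since almost nothing of x0 survives at the last step, starting inference from torch.randn is a close match for what the network saw at the end of the noising process during training.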

The code does exactly this, but because the diffusion runs in the latent space, the noise we generate also lives in the latent space.

Here is a brief introduction to the VAE. A VAE is a variational autoencoder that can encode an input picture. A 512x512x3 picture becomes a 64x64x4 latent after VAE encoding. The 4 is a design choice, so there is no need to worry about why it is not 3. The compact 64x64x4 matrix stands in for the original 512x512x3 picture, which makes it much cheaper to transmit and store. When you actually want to look at it, decoding the 64x64x4 matrix gives back a 512x512x3 picture.

Therefore, since the noise lives in the latent space, generating a 512x512x3 picture means initializing a 64x64x4 latent. After diffusing in the latent space, the decoder turns the result into a 512x512x3 image.

The initial noise is generated by this line:

img = torch.randn(shape, device=device)

This line is in the ddim_sampling method in ldm.models.diffusion.ddim.py. The shape is passed in from outside; for each sample it is [4, 64, 64].
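
As a quick sanity check of the numbers above, the latent shape can be derived from the image size. This is only a sketch: latent_shape is a hypothetical helper, and the factor 8 and the 4 latent channels are taken from the shape = (4, H // 8, W // 8) line shown earlier.

```python
import torch


def latent_shape(height, width, latent_channels=4, downsample=8):
    """Shape (C, H/8, W/8) of the latent that decodes to a height x width image."""
    assert height % downsample == 0 and width % downsample == 0
    return (latent_channels, height // downsample, width // downsample)


num_samples = 2
shape = latent_shape(512, 512)
# The batch dimension goes in front of the per-sample shape before the noise is drawn.
img = torch.randn((num_samples, *shape))
print(img.shape)
```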

b. Sampling the noise N times

Stable Diffusion is a diffusion process, so repeated denoising is indispensable; the question is how to denoise.

The previous step gave us img, a tensor drawn from a normal distribution, and denoising starts from it.

We reverse the ddim_timesteps, because we are denoising rather than adding noise, and then loop over them.

The loop contains a mask, used for local reconstruction by masking the latents of some regions; it is not used here. Everything else is a method or function call, so little can be seen at this level. The call that looks most like the sampling step is p_sample_ddim, so we need to look inside it. The code for the loop is as follows:

for i, step in enumerate(iterator):
    # index is used to look up the schedule parameters for this step
    index   = total_steps - i - 1
    # Broadcast the step number across the batch dimension
    ts      = torch.full((b,), step, device=device, dtype=torch.long)

    # Used for local reconstruction: mask the latents of some regions.
    if mask is not None:
        assert x0 is not None
        img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
        img = img_orig * mask + (1. - mask) * img

    # Run one sampling step
    outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                quantize_denoised=quantize_denoised, temperature=temperature,
                                noise_dropout=noise_dropout, score_corrector=score_corrector,
                                corrector_kwargs=corrector_kwargs,
                                unconditional_guidance_scale=unconditional_guidance_scale,
                                unconditional_conditioning=unconditional_conditioning)
    img, pred_x0 = outs
    # Callbacks
    if callback: callback(i)
    if img_callback: img_callback(pred_x0, i)

    if index % log_every_t == 0 or index == total_steps - 1:
        intermediates['x_inter'].append(img)
        intermediates['pred_x0'].append(pred_x0)
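
The relationship between the reversed timesteps and the index used in the loop above can be sketched in isolation. The helper make_uniform_ddim_timesteps is hypothetical; it mirrors what I recall of the "uniform" DDIM schedule (evenly spaced training steps with a +1 offset), so treat the exact values as an assumption.

```python
import numpy as np


def make_uniform_ddim_timesteps(num_ddim_steps, num_ddpm_steps=1000):
    """Pick evenly spaced training timesteps for a shorter sampling schedule."""
    stride = num_ddpm_steps // num_ddim_steps
    # The +1 offset is an assumption about the "uniform" schedule in the repository.
    return np.arange(0, num_ddpm_steps, stride) + 1


ddim_timesteps = make_uniform_ddim_timesteps(20)
time_range = np.flip(ddim_timesteps)      # denoise from the noisiest step down
total_steps = len(ddim_timesteps)

for i, step in enumerate(time_range):
    index = total_steps - i - 1           # position in the precomputed tables
    assert ddim_timesteps[index] == step  # index and step always line up

print(time_range[:3], time_range[-3:])
```

So step is the training timestep fed to the network, while index addresses the per-step coefficient tables, which are stored in ascending order.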


c. Single sampling analysis

I. Predicting the noise

Before each sampling step, we first check whether there is a negative prompt. If so, the negative prompt is processed together with the positive prompt; otherwise only the positive prompt is processed. In practice a negative prompt is usually supplied (the results are better), so that branch is taken by default.

To handle the negative prompt, we duplicate the input latent and the timestep: one copy belongs to the positive prompt and the other to the negative prompt. torch.cat concatenates along dimension 0 by default, so the copies are stacked along the batch dimension and do not affect each other. Then we stack the positive and negative prompts into one batch, also along the batch dimension.

# First check whether there is a negative prompt; unconditional_conditioning is derived from the negative prompt
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    e_t = self.model.apply_model(x, t, c)
else:
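    # A negative prompt is usually present, so this branch is taken.
    # Duplicate the latent and the timestep: one copy for the positive prompt, one for the negative prompt.
    # torch.cat concatenates along dim 0 by default, i.e. the batch dimension, so the two copies do not affect each other.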
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    # Then stack the positive and negative prompts into a single batch
    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        for k in c:
            if isinstance(c[k], list):
                c_in[k] = [
                    torch.cat([unconditional_conditioning[k][i], c[k][i]])
                    for i in range(len(c[k]))
                ]
            else:
                c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
    else:
        c_in = torch.cat([unconditional_conditioning, c])

After stacking, we pass the latent, the timestep, and the prompt condition into the network together, and split the result along the batch dimension with chunk.

Because the negative prompt is placed first when stacking, the first half after the split, e_t_uncond, comes from the negative prompt, and the second half, e_t, comes from the positive prompt. In essence, we want to amplify the influence of the positive prompt and move away from the influence of the negative prompt. So we compute e_t - e_t_uncond, the difference between the two, multiply it by the guidance scale, and add it to e_t_uncond to get the final noise prediction.

# After stacking, pass the latent, timestep, and prompt condition through the network, then split the result along the batch dimension with chunk
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

The resulting e_t is the predicted noise, determined jointly by the latent and the prompt.
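
The batching and guidance logic can be tried end to end with a toy stand-in for the network. In this sketch toy_noise_model is entirely hypothetical (it just broadcasts the mean of the context), and plain tensors are used for the conditions instead of the dict form handled in the real code.

```python
import torch


def toy_noise_model(x, t, context):
    """Stand-in for model.apply_model: the output depends only on the context."""
    # Hypothetical behaviour: broadcast the mean of each context over the latent.
    return torch.zeros_like(x) + context.mean(dim=(1, 2)).view(-1, 1, 1, 1)


def guided_noise(x, t, cond, uncond, scale):
    # Duplicate latent and timestep so both prompts run in one batch.
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    # Negative (unconditional) prompt first, positive prompt second.
    c_in = torch.cat([uncond, cond])
    e_t_uncond, e_t = toy_noise_model(x_in, t_in, c_in).chunk(2)
    # Start from the unconditional prediction and move toward the conditional one.
    return e_t_uncond + scale * (e_t - e_t_uncond)


x = torch.randn(1, 4, 64, 64)
t = torch.full((1,), 951, dtype=torch.long)
cond = torch.full((1, 77, 768), 1.0)     # pretend positive-prompt embedding
uncond = torch.full((1, 77, 768), 0.5)   # pretend negative-prompt embedding
e_t = guided_noise(x, t, cond, uncond, scale=9.0)
print(e_t.shape, e_t.flatten()[0].item())
```

With scale equal to 1 the result is exactly the conditional prediction, which matches the shortcut branch in the code above that skips the negative prompt when unconditional_guidance_scale == 1.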

II. Applying noise

Is getting the noise enough? Obviously not: we still need to combine the predicted noise with the current noisy latent in certain proportions.

Here it is best to refer to the formula in DDIM. We need $\bar{\alpha}_t$, $\bar{\alpha}_{t-1}$, $\sigma_t$, and $\sqrt{1-\bar{\alpha}_t}$.
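
For reference, the update these symbols belong to is, as far as I can reconstruct it from the DDIM paper and the sampler code, the following. Here $x_t$ is the current noisy latent, $\epsilon_\theta(x_t)$ is the predicted noise (e_t in the code), and $\epsilon_t$ is freshly drawn Gaussian noise:

$$
x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \left( \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t)}{\sqrt{\bar{\alpha}_t}} \right) + \sqrt{1-\bar{\alpha}_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t) + \sigma_t\,\epsilon_t
$$

The first term rescales the predicted clean latent, the second points back toward $x_t$, and the third adds random noise, which vanishes when $\sigma_t = 0$.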
In the code, these parameters are precomputed, so we only need to look them up: a_t below is $\bar{\alpha}_t$ in the formula, a_prev is $\bar{\alpha}_{t-1}$, sigma_t is $\sigma_t$, and sqrt_one_minus_at is $\sqrt{1-\bar{\alpha}_t}$.

# Select the parameter tables according to the sampler
alphas      = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas      = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

# Select the parameters for the current step.
# index here is the total_steps - i - 1 from the loop above
a_t         = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev      = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t     = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

In this step we have simply pulled out all the coefficients the formula needs, ready for the arithmetic that follows. Then we implement the formula in code.

# current prediction for x_0
# Leftmost term of the formula
pred_x0             = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
    pred_x0, _, *_  = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
# Middle term of the formula
dir_xt              = (1. - a_prev - sigma_t**2).sqrt() * e_t
# Rightmost term of the formula
noise               = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
    noise           = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev              = a_prev.sqrt() * pred_x0 + dir_xt + noise
# Return the result of applying the formula
return x_prev, pred_x0
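
To convince yourself that these three terms do the right thing, here is a standalone sketch of one update with made-up coefficients. ddim_step is a hypothetical function that follows the same arithmetic as the snippet above; the check uses a "perfect" noise prediction, which a real network only approximates.

```python
import torch


def ddim_step(x, e_t, a_t, a_prev, sigma_t=0.0, noise=None):
    """One DDIM update from x_t to x_{t-1}, given the predicted noise e_t."""
    a_t, a_prev = torch.as_tensor(a_t), torch.as_tensor(a_prev)
    sigma_t = torch.as_tensor(sigma_t)
    # Estimate the clean latent implied by the current noise prediction.
    pred_x0 = (x - (1.0 - a_t).sqrt() * e_t) / a_t.sqrt()
    # Deterministic part that points back toward x_t.
    dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
    # Optional stochastic part; it vanishes when sigma_t is 0.
    if noise is None:
        noise = torch.randn_like(x)
    x_prev = a_prev.sqrt() * pred_x0 + dir_xt + sigma_t * noise
    return x_prev, pred_x0


torch.manual_seed(0)
x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
a_t, a_prev = 0.5, 0.8                           # made-up cumulative alphas
x_t = a_t ** 0.5 * x0 + (1 - a_t) ** 0.5 * eps   # noisy latent built from x0 and eps

# With a perfect noise prediction the step recovers x0 and lands on x_{t-1}.
x_prev, pred_x0 = ddim_step(x_t, eps, a_t, a_prev)
expected_prev = a_prev ** 0.5 * x0 + (1 - a_prev) ** 0.5 * eps
print(torch.allclose(pred_x0, x0, atol=1e-5), torch.allclose(x_prev, expected_prev, atol=1e-5))
```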


d. Network structure analysis in the process of predicting noise

I. Analysis of apply_model method

In the noise-prediction step above, we used model.apply_model to predict the noise, but what the method does was hidden. Let's look at what it actually does.

The apply_model method is in the ldm.models.diffusion.ddpm.py file. In apply_model, we pass x_noisy into self.model to predict the noise.

x_recon = self.model(x_noisy, t, **cond)

self.model is a DiffusionWrapper, a class defined at line 1416 of the ldm.models.diffusion.ddpm.py file, which contains the Stable Diffusion UNet. It acts as a wrapper: depending on the feature fusion method the model is configured with, it fuses the text with the noisy latent from above.

c_concat means fusion by concatenation, and c_crossattn means fusion by attention.

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        # The Stable Diffusion UNet
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out
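
The difference between the two fusion modes is easiest to see in the tensor shapes. The sketch below only reproduces the torch.cat calls from the wrapper with random tensors; the image-like condition in c_concat is a made-up example, since text-to-image only uses c_crossattn.

```python
import torch

x = torch.randn(2, 4, 64, 64)                 # noisy latent
c_concat = [torch.randn(2, 4, 64, 64)]        # hypothetical image-like condition
c_crossattn = [torch.randn(2, 77, 768)]       # text embedding from the encoder

# 'concat': the condition becomes extra input channels of the UNet.
xc = torch.cat([x] + c_concat, dim=1)

# 'crossattn': the latent is left untouched; the text is passed as `context`
# and consumed by the attention layers inside the UNet.
cc = torch.cat(c_crossattn, 1)

print(xc.shape, cc.shape)
```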

self.diffusion_model in the code is the Stable Diffusion UNet; its structure is defined by the UNetModel class in the ldm.modules.diffusionmodules.openaimodel.py file.

II. UNetModel model analysis

The main job of UNetModel is to predict the noise at the current step from the noisy latent, the timestep t, and the text embedding. The idea behind UNet is simple, but in Stable Diffusion UNetModel is built from ResBlock and Transformer modules, which makes it more complicated overall than an ordinary UNet.

The prompt becomes the text embedding through the frozen CLIP text encoder, and the timesteps become the timestep embedding through fully connected layers (an MLP).

ResBlock fuses in the timestep embedding, and the Transformer module fuses in the text embedding.
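
A rough sketch of how a timestep turns into an embedding: the integer step is first expanded into sinusoidal features and then passed through a small MLP. The function sinusoidal_embedding below is my own stand-in for the repository's timestep_embedding helper, and model_channels = 320 is an assumed value for the base channel count.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_embedding(timesteps, dim, max_period=10000):
    """Map integer timesteps of shape [N] to sinusoidal features of shape [N, dim]."""
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


model_channels = 320                     # assumed base width of the UNet
time_embed_dim = model_channels * 4
time_embed = nn.Sequential(
    nn.Linear(model_channels, time_embed_dim),
    nn.SiLU(),
    nn.Linear(time_embed_dim, time_embed_dim),
)

timesteps = torch.tensor([951, 1])
emb = time_embed(sinusoidal_embedding(timesteps, model_channels))
print(emb.shape)
```

Each ResBlock then projects this embedding to its own channel count and adds it to the feature map, which is how every layer learns which noise level it is working at.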

Here is a big diagram in which readers can follow how the shapes change inside the network.
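
The shape changes in the downsampling half can also be traced with a few lines of Python that copy the bookkeeping of the UNet constructor. This is only a sketch: trace_encoder is a hypothetical helper, channel_mult = (1, 2, 4, 4) and two ResBlocks per stage follow the comments in the code, while model_channels = 320 and attention_resolutions = (4, 2, 1) are values I am assuming from the usual Stable Diffusion v1 configuration.

```python
def trace_encoder(model_channels=320, channel_mult=(1, 2, 4, 4), num_res_blocks=2,
                  attention_resolutions=(4, 2, 1), latent_size=64):
    """List (block, channels, spatial size, has_attention) for the downsampling half."""
    rows = [("conv_in", model_channels, latent_size, False)]
    ds, size = 1, latent_size
    for level, mult in enumerate(channel_mult):
        ch = mult * model_channels
        for _ in range(num_res_blocks):
            # A ResBlock, followed by attention if this resolution is listed.
            rows.append(("res", ch, size, ds in attention_resolutions))
        if level != len(channel_mult) - 1:
            # Every stage except the last halves the spatial size.
            ds, size = ds * 2, size // 2
            rows.append(("down", ch, size, False))
    return rows


rows = trace_encoder()
for name, ch, size, attn in rows:
    print(f"{name:8s} {ch:5d} x {size:2d} x {size:2d}  attention={attn}")
```

Under these assumptions the latent goes from 64x64 down to 8x8 while the channels grow from 320 to 1280, and the upsampling half mirrors this path using the skip connections.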

The UNet code is as follows:

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # Computes the embedding of the current sampling timestep t
        time_embed_dim  = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        
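        # Define the first convolution of the input blocks.
        # TimestepEmbedSequential can also be seen as a wrapper that fuses the time or text features depending on the layer type.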
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size  = model_channels
        input_block_chans   = [model_channels]
        ch                  = model_channels
        ds                  = 1
        # Loop over channel_mult, which has four values giving the channel
        # expansion ratio of the four UNet stages: [1, 2, 4, 4]
        for level, mult in enumerate(channel_mult):
            # Each stage loops twice,
            # adding a ResBlock and an attention block each time
            for _ in range(num_res_blocks):
                # First add a ResBlock, which adjusts the channel count
                # of the noisy input and fuses the features of t
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                # ch is the output channel count of the ResBlock above
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads=8
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # Use SpatialTransformer attention to strengthen global features and fuse the text features
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # Every stage except the last of the four is followed by downsampling.
            if level != len(channel_mult) - 1:
                out_ch = ch
                # Downsample here.
                # Usually the Downsample module is used directly.
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                # Set up the parameters for the next stage.
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # Define the middle block:
        # ResBlock + SpatialTransformer + ResBlock
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # Define the upsampling path of the UNet
        self.output_blocks = nn.ModuleList([])
        # This loop walks channel_mult in reverse
        for level, mult in list(enumerate(channel_mult))[::-1]:
            # During upsampling each stage loops three times
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                # First add the ResBlock layer
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                # Then apply SpatialTransformer attention
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                # If this is not the first entry of channel_mult
                # and
                # it is the last iteration of the num_res_blocks loop, upsample
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # Finally apply one convolution in the output head
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
            normalization(ch),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs      = []
        # Compute the embedding of the current sampling timestep t
        t_emb   = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb     = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # Loop over the input blocks: downsample and fuse the time and text features.
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)

        # Feature extraction in the middle block
        h = self.middle_block(h, emb, context)

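        # Feature extraction in the upsampling blocks; each one first
        # concatenates the matching skip feature popped from hs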
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        # Output head
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)
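
To make the skip connections in forward easier to follow, here is a small sketch that only tracks shapes, with no real layers. It mimics how hs is filled on the way down and popped on the way up. The function name trace_unet_shapes is hypothetical, and the defaults (model_channels=320, channel_mult=(1, 2, 4, 4), num_res_blocks=2, a 64x64 latent) are assumptions about the v1 config; check them against your own yaml file.

```python
def trace_unet_shapes(model_channels=320, channel_mult=(1, 2, 4, 4),
                      num_res_blocks=2, latent_size=64):
    """Trace (channels, spatial size) through a U-Net laid out like the code above."""
    skips = []                        # plays the role of `hs` in forward
    ch, size = model_channels, latent_size
    skips.append((ch, size))          # first input block: a plain convolution

    # Down path: num_res_blocks blocks per level, then one downsample
    # on every level except the last one.
    for level, mult in enumerate(channel_mult):
        for _ in range(num_res_blocks):
            ch = model_channels * mult
            skips.append((ch, size))
        if level != len(channel_mult) - 1:
            size //= 2
            skips.append((ch, size))

    n_skips = len(skips)
    concat_log = []                   # shape seen by each output block after th.cat

    # Up path: every output block pops one skip and concatenates on channels.
    for level, mult in reversed(list(enumerate(channel_mult))):
        for i in range(num_res_blocks + 1):
            skip_ch, skip_size = skips.pop()
            assert skip_size == size  # spatial sizes must match to concatenate
            concat_log.append((ch + skip_ch, size))
            ch = model_channels * mult
            if level and i == num_res_blocks:
                size *= 2             # upsample at the end of each non-zero level
    return n_skips, concat_log, (ch, size)


n_skips, concat_log, final = trace_unet_shapes()
print(n_skips)
print(concat_log)
print(final)
```

Because hs is used as a stack, the last feature pushed on the way down is the first one concatenated on the way up, which is why the spatial sizes always line up.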

3. Latent space decoding to generate pictures

After the sampling step above has been repeated, we obtain the final latent, and the picture is then generated by decoding it from the latent space.

Decoding is very simple: the decode_first_stage method turns the result of the repeated sampling into a picture.

Inside decode_first_stage, the network calls the VAE to decode the 64x64x4 latent into a 512x512x3 picture.

@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    z = 1. / self.scale_factor * z
    # The input normally does not need to be split, so z is passed straight to
    # self.first_stage_model.decode in the else branch below.
    if hasattr(self, "split_input_params"):
        ......
    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)
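
The size relation between the latent and the picture, and the way decoder values become pixels, can be sketched in plain Python. This is only an illustration: latent_shape and to_uint8 are hypothetical helpers, the factor of 8 and the 4 channels are taken from the shape used in the prediction script, and the output range of [-1, 1] is an assumption based on the * 127.5 + 127.5 conversion in that script.

```python
def latent_shape(height, width, channels=4, factor=8):
    """Shape (C, H, W) of the latent for an image of the given size."""
    if height % factor or width % factor:
        raise ValueError("height and width must be multiples of the factor")
    return (channels, height // factor, width // factor)


def to_uint8(value):
    """Map one decoder output value from [-1, 1] to an 8-bit pixel in [0, 255]."""
    scaled = value * 127.5 + 127.5
    clipped = min(max(scaled, 0.0), 255.0)   # values outside [-1, 1] are clipped
    return int(clipped)                      # truncate the fractional part


print(latent_shape(512, 512))
print([to_uint8(v) for v in (-1.0, 0.0, 1.0, 1.3)])
```

A 512x512 picture therefore corresponds to a latent with 64 times fewer spatial positions, which is the main reason diffusion in latent space is cheaper than diffusion on pixels.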

Text to Image Prediction Process Code

The overall prediction code is as follows:

import random

import einops
import numpy as np
import torch
import cv2
import os
from ldm_hacked import create_model, load_state_dict, DDIMSampler
from pytorch_lightning import seed_everything

# ----------------------- #
#   Parameters
# ----------------------- #
# Path to the config file
config_path = "model_data/sd_v15.yaml"
# Path to the model weights
model_path  = "model_data/v1-5-pruned-emaonly.safetensors"

# Size of the generated images
input_shape = [512, 512]
# Number of images generated per run
num_samples = 2
# Number of sampling steps
ddim_steps  = 20
# Sampling seed; -1 means a random seed
seed        = 12345
# eta (0 gives deterministic DDIM sampling)
eta         = 0

# Prompt
prompt      = "a cat"
# Positive prompt, appended to the prompt
a_prompt    = "best quality, extremely detailed"
# Negative prompt
n_prompt    = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# Guidance scale between the positive and negative prompts
scale       = 9

# save_path
save_path   = "imgs/outputs_imgs"

# ----------------------- #
#   Create the model
# ----------------------- #
model   = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model   = model.cuda()
ddim_sampler = DDIMSampler(model)

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   Encode the prompts
    # ----------------------- #
    cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    # ----------------------- #
    #   Run sampling
    # ----------------------- #
    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                    shape, cond, verbose=False, eta=eta,
                                                    unconditional_guidance_scale=scale,
                                                    unconditional_conditioning=un_cond)

    # ----------------------- #
    #   Decode the latents
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

# ----------------------- #
#   Save the images
# ----------------------- #
if not os.path.exists(save_path):
    os.makedirs(save_path)
for index, image in enumerate(x_samples):
    # x_samples is in RGB order, while cv2.imwrite expects BGR
    cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
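
The script passes scale and un_cond to the sampler without showing what they do. As a rough sketch, assuming the sampler follows the usual classifier-free guidance rule (the noise is predicted once with un_cond and once with cond, then the two predictions are combined), the combination looks like this. The function guided_noise and the toy lists are hypothetical and stand in for real tensors.

```python
def guided_noise(eps_uncond, eps_cond, scale):
    """Classifier-free guidance: start from the unconditional prediction and
    move toward the conditional one, amplified by `scale`."""
    return [u + scale * (c - u) for u, c in zip(eps_uncond, eps_cond)]


# Toy noise predictions for two positions of a latent
eps_uncond = [0.0, 1.0]   # predicted with the negative prompt (un_cond)
eps_cond   = [1.0, 1.0]   # predicted with the prompt (cond)

print(guided_noise(eps_uncond, eps_cond, scale=1))
print(guided_noise(eps_uncond, eps_cond, scale=9))
```

With scale=1 the result equals the conditional prediction. Larger values push the result further away from what the negative prompt describes, which is why n_prompt lists the defects we want to avoid.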

Origin blog.csdn.net/weixin_44791964/article/details/130588215