Aprendizagem do Modelo de Difusão por Difusão 3——Análise da Estrutura de Difusão Estável—Tomando a Imagem de Geração de Imagem (Picture Image, img2img) como exemplo

prefácio do estudo

Eu uso Stable Diffusion há muito tempo, mas nunca analisei sua estrutura interna direito, faça um blog para registrar isso, hehe.
insira a descrição da imagem aqui

Endereço de download do código-fonte

https://github.com/bubbliiiing/stable-diffusion

Você pode pedir uma estrela se quiser.

construção de rede

1. O que é difusão estável (SD)

Difusão estável é um modelo de difusão relativamente novo. A tradução é difusão estável. Embora o nome seja difusão estável, na verdade, os resultados gerados pela mudança da semente são completamente diferentes e muito instáveis.

A aplicação inicial do Stable Diffusion deve ser para gerar imagens a partir de texto, ou seja, gráficos de Vincent. Com o desenvolvimento da tecnologia, o Stable Diffusion não apenas suporta a geração de gráficos image2image, mas também suporta vários métodos de controle, como ControlNet, para personalizar o gerado imagens.

A Difusão Estável é baseada no modelo de difusão, portanto inclui inevitavelmente o processo de redução contínua de ruído. Se for uma imagem, também há um processo de adição contínua de ruído. Neste momento, a imagem antiga do DDPM é inseparável, conforme a seguir: Comparado com DDPM,
insira a descrição da imagem aqui
difusão estável O amostrador DDIM é usado, a difusão do espaço latente é usada e o grande conjunto de dados LAION-5B é usado para pré-treinamento.

Direct Finetune Stable Diffusion A maioria dos estudantes não deve ser capaz de cobrir o custo, mas Stable Diffusion tem muitas soluções leves de Finetune, como Lora, Textual Inversion, etc., mas esta é uma história posterior.

Este artigo analisa principalmente a composição estrutural de todo o modelo SD, o processo de uma difusão e múltiplas difusões.

Modelos grandes e AIGC são as tendências atuais da indústria, se você não souber fazer, será eliminado facilmente, hh.

O princípio do txt2img é mostrado na postagem do blog
Diffusion Diffusion Model Learning 2——Stable Diffusion Structure Analysis—Toking Text Generated Image (txt2img) como um exemplo
.

2. Composição da Difusão Estável

A difusão estável consiste em quatro partes.
1. Amostrador amostrador.
2. Autoencoder variacional (VAE) Autoencoder variacional.
3. Rede principal UNet, preditor de ruído.
4. Codificador de texto CLIPE mbedder.

Cada parte é muito importante, vamos pegar a geração de imagem como exemplo para analisar. Como é uma imagem para gerar uma imagem, temos duas entradas, uma é texto e a outra é uma imagem.

Três, processo de geração img2img

insira a descrição da imagem aqui
O processo de geração é dividido em quatro partes:
1. Execute a codificação VAE na imagem e adicione ruído de acordo com o valor de redução de ruído.
2. Prompt de codificação de texto.
3. Execute várias amostras de acordo com o valor de redução de ruído.
4. Use VAE para decodificar.

Em comparação com o gráfico de Wensheng, a entrada do gráfico gerado pelo gráfico foi alterada. Ele não é mais inicializado com ruído Gaussiano, mas com os recursos da imagem após a adição de ruído . Dessa forma, as informações são injetadas no modelo na forma de um imagem.

Em detalhes, como mostrado na figura acima:

  • A primeira etapa é usar a codificação VAE na imagem de entrada para obter o recurso Latente da imagem de entrada; em seguida, use o recurso Latente para adicionar ruído com base no DDIM Sampler e obtenha o recurso com ruído da imagem de entrada neste momento. Suponha que definamos o valor de redução de ruído para 0,8 e o número total de etapas seja 20; na primeira etapa, adicionaremos ruído à imagem de entrada 0,8x20 vezes e as 4 etapas restantes não serão adicionadas, o que pode ser entendido como atrapalhando 80% da imagem.Features, mantenha 20% das feições.
  • A segunda etapa é codificar o texto de entrada para obter recursos de texto;
  • A terceira etapa é amostrar as características com ruído obtidas na primeira etapa várias vezes de acordo com o valor de redução de ruído . Ainda tomando o valor de redução de ruído de 0,8 na primeira etapa como exemplo, adicionamos apenas 0,8x20 vezes de ruído , então precisamos realizar apenas 0,8x20 vezes de amostragem para restaurar a imagem .
  • A quarta etapa é restaurar a imagem amostrada usando o decodificador do VAE.
with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   对输入图片进行编码并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {
    
    "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {
    
    "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   进行采样
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

1. Insira o código da imagem

insira a descrição da imagem aqui

No diagrama de geração de imagens, primeiro precisamos especificar uma imagem de referência e, em seguida, começar a trabalhar nessa imagem de referência:
1. Use o codificador VAE para codificar essa imagem de referência para que ela entre no espaço latente, somente quando entrar no espaço latente , a rede sabe qual é a imagem 2.
Em seguida, use o recurso latente para adicionar ruído com base no DDIM Sampler e, em seguida, obtenha o recurso com ruído da imagem de entrada. A lógica de adicionar ruído é a seguinte:

  • o denoise pode ser considerado como a proporção de reconstrução, 1 representa toda a reconstrução, 0 representa a reconstrução por etapas;
  • Suponha que definamos o valor de redução de ruído para 0,8 e o número total de etapas seja 20; adicionaremos ruído à imagem de entrada 0,8x20 vezes e as 4 etapas restantes não serão adicionadas, o que pode ser entendido como 80% dos recursos e 20% dos recursos; no entanto, mesmo depois de adicionar 20 etapas de ruído , as informações da imagem de entrada original ainda são um tanto reservadas, não completamente desreservadas.

Neste ponto, obteremos a imagem após adicionar ruído no espaço latente e, em seguida, faremos uma amostragem com base na imagem após adicionar ruído neste espaço latente .

2. Codificação de texto

insira a descrição da imagem aqui
A ideia de codificação de texto é relativamente simples, basta usar o codificador de texto CLIP para codificar diretamente. Uma categoria FrozenCLIPEmbedder é definida no código, e o CLIPTokenizer e CLIPTextModel da biblioteca de transformadores são usados.

No processo de pré-transferência, primeiro usamos CLIPTokenizer para codificar o texto de entrada e, em seguida, usamos CLIPTextModel para extração de recursos.Através do FrozenCLIPEmbedder, podemos obter um vetor de recursos de [batch_size, 77, 768].

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        # 定义文本的tokenizer和transformer
        self.tokenizer      = CLIPTokenizer.from_pretrained(version)
        self.transformer    = CLIPTextModel.from_pretrained(version)
        self.device         = device
        self.max_length     = max_length
        # 冻结模型参数
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # 对输入的图片进行分词并编码,padding直接padding到77的长度。
        batch_encoding  = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # 拿出input_ids然后传入transformer进行特征提取。
        tokens          = batch_encoding["input_ids"].to(self.device)
        outputs         = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        # 取出所有的token
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)

3. Processo de amostragem

insira a descrição da imagem aqui

a. Gerar ruído inicial

Na imagem gerada pela imagem, nosso ruído inicial é obtido da imagem de referência, portanto, o ruído da imagem gerada pela imagem pode ser obtido referindo-se à primeira etapa

b. Amostragem do ruído N vezes

Como a difusão estável é um processo de difusão contínua, a redução contínua de ruído é indispensável, portanto, como remover o ruído é um problema.

No passo anterior, obtivemos um img, que é um vetor que obedece a uma distribuição normal, e começamos a desruir dele.

Vamos inverter o passo de tempo de ddim_timesteps, porque agora estamos eliminando ruído em vez de adicionar ruído e, em seguida, executar um loop nele. O código para o loop é o seguinte:

Existe uma máscara no loop, que é usada para reconstrução local e mascara os vetores ocultos de algumas áreas, que não é usada aqui . Todo o resto é um método ou uma função, e nada pode ser visto. O que mais se parece com o processo de amostragem é o método p_sample_ddim, precisamos entrar no método p_sample_ddim para ver.

for i, step in enumerate(iterator):
    # index是用来取得对应的调节参数的
    index   = total_steps - i - 1
    # 将步数拓展到bs维度
    ts      = torch.full((b,), step, device=device, dtype=torch.long)

    # 用于进行局部的重建,对部分区域的隐向量进行mask。
    if mask is not None:
        assert x0 is not None
        img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
        img = img_orig * mask + (1. - mask) * img

    # 进行采样
    outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                quantize_denoised=quantize_denoised, temperature=temperature,
                                noise_dropout=noise_dropout, score_corrector=score_corrector,
                                corrector_kwargs=corrector_kwargs,
                                unconditional_guidance_scale=unconditional_guidance_scale,
                                unconditional_conditioning=unconditional_conditioning)
    img, pred_x0 = outs
    # 回调函数
    if callback: callback(i)
    if img_callback: img_callback(pred_x0, i)

    if index % log_every_t == 0 or index == total_steps - 1:
        intermediates['x_inter'].append(img)
        intermediates['pred_x0'].append(pred_x0)

insira a descrição da imagem aqui

c. Análise de amostragem única

I. Ruído de previsão

Antes da amostragem de palavras, precisamos primeiro julgar se existe um prompt neg. Se for, precisamos processar o prompt neg ao mesmo tempo, caso contrário, precisamos processar apenas o prompt pos. No uso real, geralmente há um prompt negativo (o efeito será melhor), portanto, o processo de processamento correspondente é inserido por padrão.

Ao lidar com o prompt neg, copiamos o vetor oculto de entrada e o número da etapa, um pertence ao prompt pos e o outro pertence ao prompt neg. A dimensão de empilhamento padrão de arch.cat é 0, portanto, ela é empilhada na dimensão batch_size e as duas não afetarão uma à outra. Em seguida, empilhamos o prompt pos e o prompt neg em um lote, que também é empilhado na dimensão batch_size.

# 首先判断是否由neg prompt,unconditional_conditioning是由neg prompt获得的
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    e_t = self.model.apply_model(x, t, c)
else:
    # 一般都是有neg prompt的,所以进入到这里
    # 在这里我们对隐向量和步数进行复制,一个属于pos prompt,一个属于neg prompt
    # torch.cat默认堆叠维度为0,所以是在bs维度进行堆叠,二者不会互相影响
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    # 然后我们将pos prompt和neg prompt堆叠到一个batch中
    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        for k in c:
            if isinstance(c[k], list):
                c_in[k] = [
                    torch.cat([unconditional_conditioning[k][i], c[k][i]])
                    for i in range(len(c[k]))
                ]
            else:
                c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
    else:
        c_in = torch.cat([unconditional_conditioning, c])

insira a descrição da imagem aqui
Após o empilhamento, passamos o vetor oculto, o número do passo e a condição de prompt para a rede juntos e dividimos o resultado na dimensão bs usando pedaços.

Porque quando estamos empilhando, o prompt neg é colocado na frente. Portanto, após a divisão, a primeira metade e_t_uncondé obtida usando o prompt neg e a segunda metade e_té obtida usando o prompt pos. Em essência, devemos expandir a influência do prompt pos e ficar longe da influência do prompt neg . Portanto, usamos e_t-e_t_uncondpara calcular a distância entre os dois e usamos a escala para expandir a distância entre os dois. Com base em e_t_uncond, o vetor oculto final é obtido.

# 堆叠完后,隐向量、步数和prompt条件一起传入网络中,将结果在bs维度进行使用chunk进行分割
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

insira a descrição da imagem aqui
O e_t obtido neste momento é o ruído de predição obtido conjuntamente pelo vetor oculto e pelo prompt.

II. Aplicação de ruído

É normal ouvir barulho? Obviamente que não, também precisamos adicionar o novo ruído obtido ao ruído original original em uma determinada proporção.

Neste local, é melhor combinarmos a fórmula em ddim, precisamos obter α ˉ t \bar{\alpha}_taˉtα ˉ t − 1 \bar{\alpha}_{t-1}aˉt - 1σ t \sigma_tpt1 − α ˉ t \sqrt{1-\bar{\alpha}_t}1aˉt .
insira a descrição da imagem aqui
insira a descrição da imagem aqui
No código, na verdade, pré-calculamos esses parâmetros. Só precisamos tirar diretamente, o a_t abaixo é o α ˉ t \bar{\alpha}_t fora dos colchetes na fórmulaaˉt, a_prev é α ˉ t − 1 \bar{\alpha}_{t-1} na fórmulaaˉt - 1, sigma_t é o σ t \sigma_t na fórmulapt, sqrt_one_minus_at é 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} na fórmula1aˉt

# 根据采样器选择参数
alphas      = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas      = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

# 根据步数选择参数,
# 这里的index就是上面循环中的total_steps - i - 1
a_t         = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev      = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t     = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

Na verdade, nesta etapa, apenas retiramos todos os coeficientes que a fórmula precisa usar para facilitar a adição, subtração, multiplicação e divisão subsequentes. Em seguida, implementamos a fórmula acima no código.

# current prediction for x_0
# 公式中的最左边
pred_x0             = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
    pred_x0, _, *_  = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
# 公式的中间
dir_xt              = (1. - a_prev - sigma_t**2).sqrt() * e_t
# 公式最右边
noise               = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
    noise           = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev              = a_prev.sqrt() * pred_x0 + dir_xt + noise
# 输出添加完公式的结果
return x_prev, pred_x0

insira a descrição da imagem aqui

d. Análise da estrutura de rede no processo de previsão de ruído

I. Análise do método apply_model

No processo de previsão de ruído em 3.a, usamos o método model.apply_model para prever ruído. O que esse método faz está oculto. Vamos dar uma olhada no trabalho específico.

O método apply_model está no arquivo ldm.models.diffusion.ddpm.py. Em apply_model, passamos x_noisy para self.model para prever o ruído.

x_recon = self.model(x_noisy, t, **cond)

insira a descrição da imagem aqui
self.model é uma classe pré-construída definida na linha 1416 do arquivo ldm.models.diffusion.ddpm.py, que contém a rede Unet de Stable Diffusion. A função de self.model é um pouco semelhante ao wrapper, que é selecionado de acordo com o modelo O método de fusão de recursos é usado para fundir o texto e o ruído gerado acima.

c_concat representa fusão usando empilhamento e c_crossattn representa fusão usando atenção.

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        # stable diffusion的unet网络
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out

insira a descrição da imagem aqui
O self.diffusion_model no código é a rede Unet de Stable Diffusion, e a estrutura da rede está localizada na classe UNetModel no arquivo ldm.modules.diffusionmodules.openaimodel.py.

II. Análise do modelo UNetModel

O principal trabalho do UNetModel é combinar o intervalo de tempo t e a incorporação de texto para calcular o ruído neste momento . Embora a ideia da UNet seja muito simples, na StableDiffusion, a UNetModel é composta pelos módulos ResBlock e Transformer, o que é mais complicado que a UNet comum como um todo.

O prompt obtém a incorporação de texto por meio do codificador de texto CLIP congelado e o Timesteps obtém a incorporação de etapas de tempo por meio de conexão completa (MLP);

O ResBlock é usado para combinar a incorporação de etapas de tempo e o módulo Transformer é usado para combinar a incorporação de texto.

Eu coloquei uma foto grande aqui, e os alunos podem ver as mudanças na forma interna.
Adicione uma descrição da imagem

O código Unet se parece com isso:

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # 用于计算当前采样时间t的embedding
        time_embed_dim  = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        
        # 定义输入模块的第一个卷积
        # TimestepEmbedSequential也可以看作一个包装器,根据层的种类进行时间或者文本的融合。
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size  = model_channels
        input_block_chans   = [model_channels]
        ch                  = model_channels
        ds                  = 1
        # 对channel_mult进行循环,channel_mult一共有四个值,代表unet四个部分通道的扩张比例
        # [1, 2, 4, 4]
        for level, mult in enumerate(channel_mult):
            # 每个部分循环两次
            # 添加一个ResBlock和一个AttentionBlock
            for _ in range(num_res_blocks):
                # 先添加一个ResBlock
                # 用于对输入的噪声进行通道数的调整,并且融合t的特征
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                # ch便是上述ResBlock的输出通道数
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads=8
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # 使用了SpatialTransformer自注意力,加强全局特征,融合文本的特征
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # 如果不是四个部分中的最后一个部分,那么都要进行下采样。
            if level != len(channel_mult) - 1:
                out_ch = ch
                # 在此处进行下采样
                # 一般直接使用Downsample模块
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                # 为下一阶段定义参数。
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # 定义中间层
        # ResBlock + SpatialTransformer + ResBlock
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # 定义Unet上采样过程
        self.output_blocks = nn.ModuleList([])
        # 循环把channel_mult反了过来
        for level, mult in list(enumerate(channel_mult))[::-1]:
            # 上采样时每个部分循环三次
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                # 首先添加ResBlock层
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                # 然后进行SpatialTransformer自注意力
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                # 如果不是channel_mult循环的第一个
                # 且
                # 是num_res_blocks循环的最后一次,则进行上采样
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # 最后在输出部分进行一次卷积
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
            normalization(ch),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs      = []
        # 用于计算当前采样时间t的embedding
        t_emb   = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb     = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # 对输入模块进行循环,进行下采样并且融合时间特征与文本特征。
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)

        # 中间模块的特征提取
        h = self.middle_block(h, emb, context)

        # 上采样模块的特征提取
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        # 输出模块
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)

4. Decodificação de espaço oculto para gerar imagens

insira a descrição da imagem aqui
Através das etapas acima, os resultados podem ser obtidos por amostragem múltipla, e então podemos gerar imagens através da decodificação do espaço latente.

O processo de decodificação do espaço latente para gerar imagens é muito simples. Use o método decode_first_stage para gerar imagens a partir dos resultados da amostragem múltipla acima.

No método decode_first_stage, a rede chama o VAE para decodificar o vetor oculto obtido de 64x64x3 para obter uma imagem de 512x512x3.

@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    z = 1. / self.scale_factor * z
	# 一般无需分割输入,所以直接将x_noisy传入self.model中,在下面else进行
    if hasattr(self, "split_input_params"):
    	......
    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)

Código do processo de previsão de imagem para imagem

O código geral de previsão é o seguinte:

import os
import random

import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything

from ldm_hacked import *

# ----------------------- #
#   使用的参数
# ----------------------- #
# config的地址
config_path = "model_data/sd_v15.yaml"
# 模型的地址
model_path  = "model_data/v1-5-pruned-emaonly.safetensors"
# fp16,可以加速与节省显存
sd_fp16     = True
vae_fp16    = True

# ----------------------- #
#   生成图片的参数
# ----------------------- #
# 生成的图像大小为input_shape,对于img2img会进行Centter Crop
input_shape = [512, 512]
# 一次生成几张图像
num_samples = 1
# 采样的步数
ddim_steps  = 20
# 采样的种子,为-1的话则随机。
seed        = 12345
# eta
eta         = 0
# denoise强度,for img2img
denoise_strength = 1.0

# ----------------------- #
#   提示词相关参数
# ----------------------- #
# 提示词
prompt      = "a cute cat, with yellow leaf, trees"
# 正面提示词
a_prompt    = "best quality, extremely detailed"
# 负面提示词
n_prompt    = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# 正负扩大倍数
scale       = 9
# img2img使用,如果不想img2img这设置为None。
image_path  = None

# ----------------------- #
#   保存路径
# ----------------------- #
save_path   = "imgs/outputs_imgs"

# ----------------------- #
#   创建模型
# ----------------------- #
model   = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model   = model.cuda()
ddim_sampler = DDIMSampler(model)
if sd_fp16:
    model = model.half()

if image_path is not None:
    img = Image.open(image_path)
    img = crop_and_resize(img, input_shape[0], input_shape[1])

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)
    
    # ----------------------- #
    #   对输入图片进行编码并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()
        if vae_fp16:
            img = img.half()
            model.first_stage_model = model.first_stage_model.half()
        else:
            model.first_stage_model = model.first_stage_model.float()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
        z_enc = z_enc.half() if sd_fp16 else z_enc.float()

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {
    
    "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {
    
    "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   进行采样
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples.half() if vae_fp16 else samples.float())

    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

# ----------------------- #
#   保存图片
# ----------------------- #
if not os.path.exists(save_path):
    os.makedirs(save_path)
for index, image in enumerate(x_samples):
    cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

Acho que você gosta

Origin blog.csdn.net/weixin_44791964/article/details/131992399
Recomendado
Clasificación