Aprendizaje del modelo de difusión por difusión 2——Análisis de la estructura de difusión estable—tomando texto para generar imágenes como ejemplo

Prefacio de estudio

Hace tiempo que uso Stable Diffusion, pero nunca he analizado bien su estructura interna, escribe un blog para registrarlo, jeje.
inserte la descripción de la imagen aquí

Dirección de descarga del código fuente

https://github.com/bubbliiiing/stable-diffusion

Puedes pedir una estrella si te gusta.

construcción de redes

1. ¿Qué es la difusión estable (SD)?

Stable Diffusion es un modelo de difusión relativamente nuevo. La traducción es difusión estable. Aunque el nombre es difusión estable, de hecho, los resultados generados al cambiar la semilla son completamente diferentes y muy inestables.

La aplicación inicial de Stable Diffusion debería ser generar imágenes a partir de texto, es decir, gráficos de Vincent. Con el desarrollo de la tecnología, Stable Diffusion no solo admite la generación de gráficos imagen a imagen, sino que también admite varios métodos de control, como ControlNet, para personalizar los gráficos generados. imágenes

La difusión estable se basa en el modelo de difusión, por lo que inevitablemente incluye el proceso de eliminación continua de ruido. Si se trata de una imagen, también hay un proceso de adición continua de ruido. En este momento, la imagen anterior de DDPM es inseparable, de la siguiente manera: En comparación con DDPM,
inserte la descripción de la imagen aquí
Difusión estable Se usa el muestreador DDIM, se usa la difusión del espacio latente y el gran conjunto de datos LAION-5B se usa para el entrenamiento previo.

Direct Finetune Stable Diffusion La mayoría de los estudiantes no deberían poder cubrir el costo, pero Stable Diffusion tiene muchas soluciones ligeras de Finetune, como Lora, Textual Inversion, etc., pero esta es una historia posterior.

Este artículo analiza principalmente la composición estructural de todo el modelo SD, el proceso de difusión única y difusión múltiple.

Los modelos grandes y AIGC son las tendencias actuales de la industria, si no sabes cómo hacerlo, serás eliminado fácilmente, hh.

2. Composición de difusión estable

Stable Diffusion consta de cuatro partes.
1. Muestreador de muestras.
2. Autocodificador variacional (VAE) Autocodificador variacional.
3. Red principal UNet, predictor de ruido.
4. Codificador de texto CLIPE mbedder.

Cada parte es muy importante, primero tomamos texto para generar imágenes como ejemplo para analizar. Dado que la imagen se genera a partir de texto, nuestra entrada es solo texto y no hay ninguna imagen de entrada en este momento.

3. Proceso de generación

inserte la descripción de la imagen aquí
El proceso de generación se divide en tres partes:
1. Codificación de texto rápido.
2. Tomar varias muestras.
3. Decodificar.

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {
    
    "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {
    
    "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    # ----------------------- #
    #   进行采样
    # ----------------------- #
    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                    shape, cond, verbose=False, eta=eta,
                                                    unconditional_guidance_scale=scale,
                                                    unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

1. Codificación de texto

inserte la descripción de la imagen aquí
La idea de la codificación de texto es relativamente simple, solo use el codificador de texto CLIP para codificar directamente.Se define una categoría FrozenCLIPEmbedder en el código, y se usan CLIPTokenizer y CLIPTextModel de la biblioteca de transformadores.

En el proceso previo a la transferencia, primero usamos CLIPTokenizer para codificar el texto de entrada y luego usamos CLIPTextModel para la extracción de características. A través de FrozenCLIPEmbedder, podemos obtener un vector de características de [batch_size, 77, 768].

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        # 定义文本的tokenizer和transformer
        self.tokenizer      = CLIPTokenizer.from_pretrained(version)
        self.transformer    = CLIPTextModel.from_pretrained(version)
        self.device         = device
        self.max_length     = max_length
        # 冻结模型参数
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # 对输入的图片进行分词并编码,padding直接padding到77的长度。
        batch_encoding  = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # 拿出input_ids然后传入transformer进行特征提取。
        tokens          = batch_encoding["input_ids"].to(self.device)
        outputs         = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        # 取出所有的token
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)

2. Proceso de muestreo

inserte la descripción de la imagen aquí

a. Generar ruido inicial

Dado que solo hay texto en la entrada y no hay imagen de entrada, ¿de dónde viene el ruido inicial?

Es suficiente crear directamente una distribución normal de ruido aquí. La comprensión simple es: dado que la imagen original se agrega continuamente con ruido de distribución normal durante el entrenamiento para obtener la matriz de ruido final, entonces inicializo directamente una distribución normal de ruido Es razonable para generar imágenes como ruido inicial .

En realidad, esto se hace en el código, pero debido a que difundimos en el espacio oculto , el ruido que generamos también es relativo al espacio oculto .

Aquí hay una breve introducción a VAE. VAE es un codificador automático variacional, que puede codificar imágenes de entrada. Una imagen con una altura y un ancho de 512x512x3 se convertirá en 64x64x4 después de codificarse con VAE. Este 4 está configurado artificialmente, así que no se preocupe por es por qué no 3 . En este momento, usamos una matriz simple para reemplazar la imagen original de 512x512x3 y los costos de transmisión y almacenamiento son muy bajos. Cuando realmente quiera verlo, puede decodificar la matriz de 64x64x4 para obtener una imagen de 512x512x3.

Por lo tanto, si el ruido que generamos es relativo al espacio oculto y queremos generar una imagen de 512 x 512 x 3, entonces necesitamos inicializar un vector oculto de 64 x 64 x 4. Después de difundir el espacio latente, podemos usar el decodificador para generar una imagen de 512 x 512 x 3.

En el código, hacemos exactamente eso, el código de generación de ruido inicial es:

img = torch.randn(shape, device=device)

El código está en el método ddim_sampling en ldm.models.diffusion.ddim.py. La forma se pasa desde el exterior y el tamaño es [4, 64, 64].
inserte la descripción de la imagen aquí

b. Muestrear el ruido N veces

Dado que la difusión estable es un proceso de difusión continua, la eliminación de ruido continua es indispensable, por lo que la eliminación de ruido es un problema.

En el paso anterior hemos obtenido un img, que es un vector que se ajusta a una distribución normal, y a partir de él empezamos a eliminar el ruido.

Invertiremos el paso de tiempo de ddim_timesteps, porque ahora estamos eliminando ruido en lugar de agregar ruido, y luego realizaremos un ciclo en él. El código para el ciclo es el siguiente:

Hay una máscara en el bucle, que se usa para la reconstrucción local y enmascara los vectores ocultos de algunas áreas, que no se usa aquí . Todo lo demás es un método o una función, y no se ve nada. El que más se parece al proceso de muestreo es el método p_sample_ddim, necesitamos ingresar al método p_sample_ddim para ver.

for i, step in enumerate(iterator):
    # index是用来取得对应的调节参数的
    index   = total_steps - i - 1
    # 将步数拓展到bs维度
    ts      = torch.full((b,), step, device=device, dtype=torch.long)

    # 用于进行局部的重建,对部分区域的隐向量进行mask。
    if mask is not None:
        assert x0 is not None
        img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
        img = img_orig * mask + (1. - mask) * img

    # 进行采样
    outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                quantize_denoised=quantize_denoised, temperature=temperature,
                                noise_dropout=noise_dropout, score_corrector=score_corrector,
                                corrector_kwargs=corrector_kwargs,
                                unconditional_guidance_scale=unconditional_guidance_scale,
                                unconditional_conditioning=unconditional_conditioning)
    img, pred_x0 = outs
    # 回调函数
    if callback: callback(i)
    if img_callback: img_callback(pred_x0, i)

    if index % log_every_t == 0 or index == total_steps - 1:
        intermediates['x_inter'].append(img)
        intermediates['pred_x0'].append(pred_x0)

inserte la descripción de la imagen aquí

c. Análisis de muestreo único

I. Ruido de predicción

Antes del muestreo de palabras, primero debemos juzgar si hay un indicador de neg. Si es así, debemos procesar el indicador de neg al mismo tiempo, de lo contrario, solo necesitamos procesar el indicador de pos. En el uso real, generalmente hay un indicador negativo (el efecto será mejor), por lo que el proceso de procesamiento correspondiente se ingresa de forma predeterminada.

Cuando se trata de una solicitud de neg, copiamos el vector oculto de entrada y el número de paso, uno pertenece a la solicitud de pos y el otro pertenece a la solicitud de neg. La dimensión de apilamiento predeterminada de torch.cat es 0, por lo que se apila en la dimensión batch_size y las dos no se afectarán entre sí. Luego, apilamos el indicador pos y el indicador neg en un lote, que también se apila en la dimensión batch_size.

# 首先判断是否由neg prompt,unconditional_conditioning是由neg prompt获得的
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    e_t = self.model.apply_model(x, t, c)
else:
    # 一般都是有neg prompt的,所以进入到这里
    # 在这里我们对隐向量和步数进行复制,一个属于pos prompt,一个属于neg prompt
    # torch.cat默认堆叠维度为0,所以是在bs维度进行堆叠,二者不会互相影响
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    # 然后我们将pos prompt和neg prompt堆叠到一个batch中
    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        for k in c:
            if isinstance(c[k], list):
                c_in[k] = [
                    torch.cat([unconditional_conditioning[k][i], c[k][i]])
                    for i in range(len(c[k]))
                ]
            else:
                c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
    else:
        c_in = torch.cat([unconditional_conditioning, c])

inserte la descripción de la imagen aquí
Después de apilar, pasamos juntos el vector oculto, el número de paso y la condición de solicitud a la red, y dividimos el resultado en la dimensión bs usando fragmentos.

Porque cuando estamos apilando, el indicador neg se coloca al frente. Por lo tanto, después de la división, la primera mitad e_t_uncondse obtiene usando el mensaje neg, y la segunda mitad e_tse obtiene usando el mensaje pos. En esencia, debemos expandir la influencia del mensaje pos y alejarnos de la influencia del mensaje neg. rápido _ Por lo tanto, usamos e_t-e_t_uncondpara calcular la distancia entre los dos y usamos la escala para expandir la distancia entre los dos. Sobre la base de e_t_uncond, se obtiene el vector oculto final.

# 堆叠完后,隐向量、步数和prompt条件一起传入网络中,将结果在bs维度进行使用chunk进行分割
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

inserte la descripción de la imagen aquí
El e_t obtenido en este momento es el ruido de predicción obtenido conjuntamente por el vector oculto y el aviso.

II. Aplicación de ruido

¿Está bien escuchar el ruido? Obviamente no, también necesitamos agregar el nuevo ruido obtenido al ruido original original en cierta proporción.

En este lugar, será mejor que combinemos la fórmula en ddim, necesitamos obtener α ˉ t \bar{\alpha}_taˉtα ˉ t − 1 \bar{\alpha}_{t-1}aˉt 1σ t \sigma_tpagt1 − α ˉ t \sqrt{1-\bar{\alpha}_t}1aˉt .
inserte la descripción de la imagen aquí
inserte la descripción de la imagen aquí
En el código, en realidad hemos precalculado estos parámetros. Solo necesitamos sacarlo directamente, el a_t debajo es el α ˉ t \bar{\alpha}_t fuera de los corchetes en la fórmulaaˉt, a_prev es α ˉ t − 1 \bar{\alpha}_{t-1} en la fórmulaaˉt 1, sigma_t es el σ t \sigma_t en la fórmulapagt, sqrt_one_minus_at es 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} en la fórmula1aˉt

# 根据采样器选择参数
alphas      = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas      = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

# 根据步数选择参数,
# 这里的index就是上面循环中的total_steps - i - 1
a_t         = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev      = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t     = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

De hecho, en este paso, simplemente sacamos todos los coeficientes que la fórmula necesita usar para facilitar las sumas, restas, multiplicaciones y divisiones posteriores. Luego implementamos la fórmula anterior en el código.

# current prediction for x_0
# 公式中的最左边
pred_x0             = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
    pred_x0, _, *_  = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
# 公式的中间
dir_xt              = (1. - a_prev - sigma_t**2).sqrt() * e_t
# 公式最右边
noise               = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
    noise           = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev              = a_prev.sqrt() * pred_x0 + dir_xt + noise
# 输出添加完公式的结果
return x_prev, pred_x0

inserte la descripción de la imagen aquí

D. Análisis de la estructura de la red en el proceso de predicción del ruido.

I. Análisis del método apply_model

En el proceso de predicción del ruido en 3.a, usamos el método model.apply_model para predecir el ruido. Lo que hace este método está oculto. Echemos un vistazo al trabajo específico.

El método apply_model se encuentra en el archivo ldm.models.diffusion.ddpm.py. En apply_model, pasamos x_noisy a self.model para predecir el ruido.

x_recon = self.model(x_noisy, t, **cond)

inserte la descripción de la imagen aquí
self.model es una clase predefinida definida en la línea 1416 del archivo ldm.models.diffusion.ddpm.py, que contiene la red Unet de Stable Diffusion. La función de self.model es algo similar al contenedor, que es seleccionado según el modelo El método de fusión de características se utiliza para fusionar el texto y el ruido generado anteriormente.

c_concat representa la fusión mediante el apilamiento y c_crossattn representa la fusión mediante la atención.

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        # stable diffusion的unet网络
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out

inserte la descripción de la imagen aquí
El self.diffusion_model en el código es la red Unet de Stable Diffusion, y la estructura de la red se encuentra en la clase UNetModel en el archivo ldm.modules.diffusionmodules.openaimodel.py.

II. Análisis del modelo UNetModel

El trabajo principal de UNetModel es combinar el paso de tiempo ty la incrustación de texto para calcular el ruido en este momento . Aunque la idea de UNet es muy simple, en StableDiffusion, UNetModel se compone de módulos ResBlock y Transformer, que es más complicado que UNet ordinario en su conjunto.

Prompt obtiene Text Embedding a través de Frozen CLIP Text Encoder, y Timesteps obtiene Timesteps Embedding a través de conexión completa (MLP);

ResBlock se usa para combinar la incrustación de pasos de tiempo, y el módulo Transformer se usa para combinar la incrustación de texto.

Pongo una imagen grande aquí, y los estudiantes pueden ver los cambios en la forma interna.
Por favor agregue una descripción de la imagen

El código de Unet se ve así:

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # 用于计算当前采样时间t的embedding
        time_embed_dim  = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        
        # 定义输入模块的第一个卷积
        # TimestepEmbedSequential也可以看作一个包装器,根据层的种类进行时间或者文本的融合。
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size  = model_channels
        input_block_chans   = [model_channels]
        ch                  = model_channels
        ds                  = 1
        # 对channel_mult进行循环,channel_mult一共有四个值,代表unet四个部分通道的扩张比例
        # [1, 2, 4, 4]
        for level, mult in enumerate(channel_mult):
            # 每个部分循环两次
            # 添加一个ResBlock和一个AttentionBlock
            for _ in range(num_res_blocks):
                # 先添加一个ResBlock
                # 用于对输入的噪声进行通道数的调整,并且融合t的特征
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                # ch便是上述ResBlock的输出通道数
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads=8
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # 使用了SpatialTransformer自注意力,加强全局特征,融合文本的特征
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # 如果不是四个部分中的最后一个部分,那么都要进行下采样。
            if level != len(channel_mult) - 1:
                out_ch = ch
                # 在此处进行下采样
                # 一般直接使用Downsample模块
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                # 为下一阶段定义参数。
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # 定义中间层
        # ResBlock + SpatialTransformer + ResBlock
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # 定义Unet上采样过程
        self.output_blocks = nn.ModuleList([])
        # 循环把channel_mult反了过来
        for level, mult in list(enumerate(channel_mult))[::-1]:
            # 上采样时每个部分循环三次
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                # 首先添加ResBlock层
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                # 然后进行SpatialTransformer自注意力
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                # 如果不是channel_mult循环的第一个
                # 且
                # 是num_res_blocks循环的最后一次,则进行上采样
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # 最后在输出部分进行一次卷积
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
            normalization(ch),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs      = []
        # 用于计算当前采样时间t的embedding
        t_emb   = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb     = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # 对输入模块进行循环,进行下采样并且融合时间特征与文本特征。
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)

        # 中间模块的特征提取
        h = self.middle_block(h, emb, context)

        # 上采样模块的特征提取
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        # 输出模块
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)

3. Decodificación de espacio oculto para generar imágenes.

inserte la descripción de la imagen aquí
A través de los pasos anteriores, los resultados se pueden obtener mediante muestreo múltiple y luego podemos generar imágenes a través de la decodificación del espacio latente.

El proceso de decodificación del espacio latente para generar imágenes es muy sencillo. Utilice el método decode_first_stage para generar imágenes a partir del resultado del muestreo múltiple anterior.

En el método decode_first_stage, la red llama a VAE para decodificar el vector oculto obtenido de 64x64x3 para obtener una imagen de 512x512x3.

@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    z = 1. / self.scale_factor * z
	# 一般无需分割输入,所以直接将x_noisy传入self.model中,在下面else进行
    if hasattr(self, "split_input_params"):
    	......
    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)

Código de proceso de predicción de texto a imagen

El código de predicción general es el siguiente:

import random

import einops
import numpy as np
import torch
import cv2
import os
from ldm_hacked import DDIMSampler
from ldm_hacked import create_model, load_state_dict, DDIMSampler
from pytorch_lightning import seed_everything

# ----------------------- #
#   使用的参数
# ----------------------- #
# config的地址
config_path = "model_data/sd_v15.yaml"
# 模型的地址
model_path  = "model_data/v1-5-pruned-emaonly.safetensors"

# 生成的图像大小为input_shape
input_shape = [512, 512]
# 一次生成几张图像
num_samples = 2
# 采样的步数
ddim_steps  = 20
# 采样的种子,为-1的话则随机。
seed        = 12345
# eta
eta         = 0

# 提示词
prompt      = "a cat"
# 正面提示词
a_prompt    = "best quality, extremely detailed"
# 负面提示词
n_prompt    = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# 正负扩大倍数
scale       = 9

# save_path
save_path   = "imgs/outputs_imgs"

# ----------------------- #
#   创建模型
# ----------------------- #
model   = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model   = model.cuda()
ddim_sampler = DDIMSampler(model)

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {
    
    "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {
    
    "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    # ----------------------- #
    #   进行采样
    # ----------------------- #
    samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                    shape, cond, verbose=False, eta=eta,
                                                    unconditional_guidance_scale=scale,
                                                    unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

# ----------------------- #
#   保存图片
# ----------------------- #
if not os.path.exists(save_path):
    os.makedirs(save_path)
for index, image in enumerate(x_samples):
    cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

Supongo que te gusta

Origin blog.csdn.net/weixin_44791964/article/details/130588215
Recomendado
Clasificación