확산 확산 모델 학습 3——안정적인 확산 구조 분석—이미지 생성 이미지(Picture Image, img2img)를 예로 들어

공부 서문

Stable Diffusion을 오랫동안 사용해 왔지만 내부 구조를 제대로 분석한 적이 없었습니다. 기록하기 위해 블로그를 작성하십시오, hehe.
여기에 이미지 설명 삽입

소스코드 다운로드 주소

https://github.com/bubbliiiing/stable-diffusion

당신이 그것을 좋아한다면 당신은 별을 주문할 수 있습니다.

네트워크 구축

1. 안정 확산(SD)이란?

Stable Diffusion은 비교적 새로운 확산 모델입니다.번역은 안정 확산입니다.이름은 안정 확산이지만 실제로 시드를 변경하여 생성된 결과는 완전히 다르고 매우 불안정합니다.

Stable Diffusion의 초기 적용은 텍스트, 즉 Vincent 그래프에서 이미지를 생성하는 것이어야 하며, 기술의 발전으로 Stable Diffusion은 image2image 그래프 생성을 지원할 뿐만 아니라 생성된 그래프를 사용자 정의할 수 있는 ControlNet과 같은 다양한 제어 방법을 지원합니다. 이미지.

Stable Diffusion은 확산 모델을 기반으로 하므로 필연적으로 지속적인 노이즈 제거 과정이 포함됩니다. DDPM과 비교하여 Stable
여기에 이미지 설명 삽입
Diffusion DDIM 샘플러를 사용하고 latent space의 확산을 사용하며 pre-training에는 매우 큰 LAION-5B 데이터셋을 사용한다.

Direct Finetune Stable Diffusion 대부분의 학생들은 비용을 감당할 수 없지만 Stable Diffusion은 Lora, Textual Inversion 등과 같은 경량 Finetune 솔루션이 많이 있지만 이는 나중에 이야기합니다.

이 글은 주로 전체 SD 모델의 구조적 구성, 1회 확산과 다중 확산 과정을 분석한다.

대형 모델과 AIGC는 현재 업계 트렌드입니다. 방법을 모르면 쉽게 도태됩니다.


txt2img의 원리 는 Diffusion Diffusion Model Learning 2——Stable Diffusion Structure Analysis—Taking Text Generated Image(txt2img) 블로그 게시물에 예시되어 있습니다
.

2. 안정적인 확산의 구성

안정적인 확산은 네 부분으로 구성됩니다.
1. 샘플러 샘플러.
2. VAE(Variational Autoencoder) 변이 자동 인코더.
3. UNet 메인 네트워크, 노이즈 예측기.
4. CLIPE mbedder 텍스트 인코더.

모든 부분이 매우 중요하므로 이미지 생성을 예로 들어 분석해 보겠습니다. 이미지를 생성하는 것은 이미지이기 때문에 두 개의 입력이 있습니다. 하나는 텍스트이고 다른 하나는 그림입니다.

셋, img2img 생성 과정

여기에 이미지 설명 삽입
생성 프로세스는 네 부분으로 나뉩니다.
1. 그림에서 VAE 인코딩을 수행하고 노이즈 제거 값에 따라 노이즈를 추가합니다.
2. 프롬프트 텍스트 인코딩.
3. 노이즈 제거 값에 따라 여러 샘플을 수행합니다.
4. VAE를 사용하여 디코딩합니다.

Wensheng 그래프와 비교하여 그래프 생성 그래프의 입력이 변경되었으며 더 이상 가우시안 노이즈로 초기화되지 않고 노이즈를 추가한 후 이미지 특징 으로 정보가 모델에 주입됩니다 . 영상.

자세히 살펴보면 위의 그림과 같이

  • 첫 번째 단계는 입력 이미지에 VAE 인코딩을 사용하여 입력 이미지의 Latent 특성을 얻은 다음 Latent 특성을 사용하여 DDIM Sampler를 기반으로 노이즈를 추가하고 이때 입력 이미지의 노이즈 특성을 얻는 것입니다. 노이즈 제거 값을 0.8로 설정하고 총 단계 수가 20이라고 가정하면 첫 번째 단계에서 입력 이미지에 노이즈를 0.8x20배 추가하고 나머지 4단계는 추가하지 않는 것으로 이해할 수 있습니다. 이미지의 80%를 방해하는 것으로 특징은 특징의 20%를 유지합니다.
  • 두 번째 단계는 입력 텍스트를 인코딩하여 텍스트 특징을 얻는 것입니다.
  • 세 번째 단계는 첫 번째 단계에서 얻은 잡음 특징을 노이즈 제거 값에 따라 여러 번 샘플링하는 것 입니다. 여전히 첫 번째 단계에서 노이즈 제거 값 0.8을 예로 들어 0.8x20배의 노이즈만 추가한 다음 그림을 복원하기 위해 0.8x20배의 샘플링만 수행하면 됩니다 .
  • 네 번째 단계는 VAE의 디코더를 사용하여 샘플링된 이미지를 복원하는 것입니다.
with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)

    # ----------------------- #
    #   对输入图片进行编码并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {
    
    "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {
    
    "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   进行采样
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples)
    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

1. 이미지 코드 입력

여기에 이미지 설명 삽입

이미지 생성 다이어그램에서 먼저 참조 이미지를 지정한 다음 이 참조 이미지에 대한 작업을 시작해야 합니다.
1. VAE 인코더를 사용하여 이 참조 이미지를 인코딩하여 잠재 공간에 들어갈 때만 잠재 공간에 들어가도록 합니다. 2. 그런 다음 Latent 기능
을 사용하여 DDIM Sampler를 기반으로 노이즈를 추가한 다음 입력 이미지의 노이즈가 있는 기능을 얻습니다. 노이즈를 추가하는 논리는 다음과 같습니다.

  • 노이즈 제거는 재구성 비율로 간주할 수 있으며, 1은 모든 재구성을 나타내고, 0은 단계적 재구성을 나타냅니다.
  • 노이즈 제거 값을 0.8로 설정하고 총 단계 수는 20이라고 가정하면 입력 이미지에 노이즈를 0.8x20번 추가하고 나머지 4단계는 추가하지 않으므로 특성의 80%로 이해할 수 있습니다. 그러나 20단계의 노이즈를 추가한 후에도 원래 입력 이미지의 정보는 여전히 어느 정도 보존되어 있으며 완전히 보존되지 않은 것은 아닙니다 .

이때 latent space 에 노이즈를 추가한 후 이미지를 얻고 , 이 latent space 에 노이즈를 추가한 후 이미지를 기반으로 샘플링을 수행합니다 .

2. 텍스트 인코딩

여기에 이미지 설명 삽입
텍스트 인코딩의 아이디어는 비교적 간단합니다. CLIP 텍스트 인코더를 사용하여 직접 인코딩하면 됩니다. 코드에 FrozenCLIPEmbedder 카테고리가 정의되어 있으며 변환기 라이브러리의 CLIPTokenizer 및 CLIPTextModel이 사용됩니다.

사전 전송 과정에서 먼저 CLIPTokenizer를 사용하여 입력 텍스트를 인코딩한 다음 CLIPTextModel을 사용하여 특징 추출을 수행하고 FrozenCLIPEmbedder를 통해 [batch_size, 77, 768]의 특징 벡터를 얻을 수 있습니다.

class FrozenCLIPEmbedder(AbstractEncoder):
    """Uses the CLIP transformer encoder for text (from huggingface)"""
    LAYERS = [
        "last",
        "pooled",
        "hidden"
    ]
    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
        super().__init__()
        assert layer in self.LAYERS
        # 定义文本的tokenizer和transformer
        self.tokenizer      = CLIPTokenizer.from_pretrained(version)
        self.transformer    = CLIPTextModel.from_pretrained(version)
        self.device         = device
        self.max_length     = max_length
        # 冻结模型参数
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        if layer == "hidden":
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        self.transformer = self.transformer.eval()
        # self.train = disabled_train
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        # 对输入的图片进行分词并编码,padding直接padding到77的长度。
        batch_encoding  = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
                                        return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        # 拿出input_ids然后传入transformer进行特征提取。
        tokens          = batch_encoding["input_ids"].to(self.device)
        outputs         = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
        # 取出所有的token
        if self.layer == "last":
            z = outputs.last_hidden_state
        elif self.layer == "pooled":
            z = outputs.pooler_output[:, None, :]
        else:
            z = outputs.hidden_states[self.layer_idx]
        return z

    def encode(self, text):
        return self(text)

3. 샘플링 프로세스

여기에 이미지 설명 삽입

a. 초기 노이즈 생성

이미지 생성 이미지에서 우리의 초기 노이즈는 참조 이미지에서 얻으므로 첫 번째 단계를 참조하여 이미지 생성 이미지의 노이즈를 얻을 수 있습니다.

b.노이즈 샘플링 N번

Stable Diffusion은 연속적인 확산 과정이기 때문에 지속적인 노이즈 제거가 필수이므로 어떻게 노이즈를 제거하느냐가 문제입니다.

이전 단계에서 우리는 정규 분포를 따르는 벡터인 img를 얻었고 여기에서 노이즈 제거를 시작합니다.

이제 노이즈를 추가하는 대신 노이즈를 제거하기 때문에 ddim_timesteps의 시간 단계를 반전한 다음 루프를 수행합니다. 루프 코드는 다음과 같습니다.

루프에는 로컬 재구성에 사용되는 마스크가 있으며 여기에서는 사용되지 않는 일부 영역의 숨겨진 벡터를 마스크합니다 . 다른 모든 것은 메서드 또는 함수이며 아무것도 볼 수 없습니다. 샘플링 프로세스와 가장 유사한 것은 p_sample_ddim 메서드입니다. 확인하려면 p_sample_ddim 메서드를 입력해야 합니다.

for i, step in enumerate(iterator):
    # index是用来取得对应的调节参数的
    index   = total_steps - i - 1
    # 将步数拓展到bs维度
    ts      = torch.full((b,), step, device=device, dtype=torch.long)

    # 用于进行局部的重建,对部分区域的隐向量进行mask。
    if mask is not None:
        assert x0 is not None
        img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
        img = img_orig * mask + (1. - mask) * img

    # 进行采样
    outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                quantize_denoised=quantize_denoised, temperature=temperature,
                                noise_dropout=noise_dropout, score_corrector=score_corrector,
                                corrector_kwargs=corrector_kwargs,
                                unconditional_guidance_scale=unconditional_guidance_scale,
                                unconditional_conditioning=unconditional_conditioning)
    img, pred_x0 = outs
    # 回调函数
    if callback: callback(i)
    if img_callback: img_callback(pred_x0, i)

    if index % log_every_t == 0 or index == total_steps - 1:
        intermediates['x_inter'].append(img)
        intermediates['pred_x0'].append(pred_x0)

여기에 이미지 설명 삽입

c. 단일 샘플링 분석

I. 예측 잡음

단어 샘플링을 하기 전에 먼저 neg 프롬프트가 있는지 판단해야 하고, 그렇다면 neg 프롬프트를 동시에 처리해야 하고, 그렇지 않으면 pos 프롬프트만 처리하면 됩니다. 실제 사용에서는 일반적으로 부정 프롬프트(효과가 더 좋을 것임)가 있으므로 해당 처리 프로세스가 기본적으로 입력됩니다.

neg 프롬프트를 처리할 때 입력된 숨겨진 벡터와 단계 번호를 복사합니다. 하나는 pos 프롬프트에 속하고 다른 하나는 neg 프롬프트에 속합니다. torch.cat의 기본 stacking 차원은 0이므로 batch_size 차원에 쌓이게 되며, 둘은 서로 영향을 주지 않습니다. 그런 다음 pos 프롬프트와 neg 프롬프트를 배치로 쌓는데, 이 역시 batch_size 차원에 쌓입니다.

# 首先判断是否由neg prompt,unconditional_conditioning是由neg prompt获得的
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
    e_t = self.model.apply_model(x, t, c)
else:
    # 一般都是有neg prompt的,所以进入到这里
    # 在这里我们对隐向量和步数进行复制,一个属于pos prompt,一个属于neg prompt
    # torch.cat默认堆叠维度为0,所以是在bs维度进行堆叠,二者不会互相影响
    x_in = torch.cat([x] * 2)
    t_in = torch.cat([t] * 2)
    # 然后我们将pos prompt和neg prompt堆叠到一个batch中
    if isinstance(c, dict):
        assert isinstance(unconditional_conditioning, dict)
        c_in = dict()
        for k in c:
            if isinstance(c[k], list):
                c_in[k] = [
                    torch.cat([unconditional_conditioning[k][i], c[k][i]])
                    for i in range(len(c[k]))
                ]
            else:
                c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
    else:
        c_in = torch.cat([unconditional_conditioning, c])

여기에 이미지 설명 삽입
스태킹 후 히든 벡터, 단계 번호, 프롬프트 조건을 함께 네트워크에 전달하고 결과를 청크를 사용하여 bs 차원으로 나눕니다.

쌓을 때 neg 프롬프트가 앞에 놓이기 때문입니다. 따라서 분할 후 전반부는 e_t_uncondneg 프롬프트를 이용하여 획득하고, 후반부는 e_tpos 프롬프트를 이용하여 획득하게 되는데, 본질적으로 pos 프롬프트의 영향력을 확대하고 neg 프롬프트의 영향을 멀리해야 한다. . 따라서 우리는 e_t-e_t_uncond둘 사이의 거리를 계산하기 위해 를 사용하고 둘 사이의 거리를 확장하기 위해 스케일을 사용합니다. e_t_uncond를 기반으로 최종 숨겨진 벡터를 얻습니다.

# 堆叠完后,隐向量、步数和prompt条件一起传入网络中,将结果在bs维度进行使用chunk进行分割
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

여기에 이미지 설명 삽입
이때 구한 e_t는 은닉 벡터와 프롬프트가 합동으로 구한 예측 잡음이다.

II. 노이즈 적용

노이즈를 잡아도 괜찮나요? 분명히 아닙니다. 우리는 또한 원래의 원래 노이즈에 특정 비율로 얻은 새로운 노이즈를 추가해야 합니다.

여기서는 ddim의 공식을 결합하는 것이 좋습니다. α ˉ t \bar{\alpha}_t를 얻어야 합니다.ˉα ˉ t − 1 \bar{\alpha}_{t-1}ˉt - 1σ t \sigma_t1 − α ˉ t \sqrt{1-\bar{\alpha}_t}1-ˉ ..
여기에 이미지 설명 삽입
여기에 이미지 설명 삽입
코드에서 실제로 이러한 매개변수를 미리 계산했습니다. 직접 꺼내기만 하면 되는데, 아래의 a_t는 수식에서 괄호 밖의 α ˉ t \bar{\alpha}_t 입니다.ˉ, a_prev는 공식에서 α ˉ t − 1 \bar{\alpha}_{t-1} 입니다.ˉt - 1, sigma_t는 공식에서 σ t \sigma_t 입니다., sqrt_one_minus_at는 공식에서 1 − α ˉ t \sqrt{1-\bar{\alpha}_t} 입니다.1-ˉ .

# 根据采样器选择参数
alphas      = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas      = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

# 根据步数选择参数,
# 这里的index就是上面循环中的total_steps - i - 1
a_t         = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev      = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t     = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)

사실, 이 단계에서는 수식이 후속 덧셈, 뺄셈, 곱셈 및 나눗셈을 용이하게 하기 위해 사용해야 하는 모든 계수를 제거했습니다. 그런 다음 위 공식을 코드로 구현합니다.

# current prediction for x_0
# 公式中的最左边
pred_x0             = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
    pred_x0, _, *_  = self.model.first_stage_model.quantize(pred_x0)
# direction pointing to x_t
# 公式的中间
dir_xt              = (1. - a_prev - sigma_t**2).sqrt() * e_t
# 公式最右边
noise               = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
if noise_dropout > 0.:
    noise           = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev              = a_prev.sqrt() * pred_x0 + dir_xt + noise
# 输出添加完公式的结果
return x_prev, pred_x0

여기에 이미지 설명 삽입

d. 노이즈 예측 과정에서 네트워크 구조 분석

I. apply_model 방식 분석

3.a에서 노이즈를 예측하는 과정에서 노이즈를 예측하기 위해 model.apply_model 메소드를 사용하는데, 이 메소드가 하는 일은 숨겨져 있습니다. 구체적인 작업을 살펴보겠습니다.

apply_model 메서드는 ldm.models.diffusion.ddpm.py 파일에 있습니다. apply_model에서 x_noisy를 self.model에 전달하여 노이즈를 예측합니다.

x_recon = self.model(x_noisy, t, **cond)

여기에 이미지 설명 삽입
self.model은 Stable Diffusion의 Unet 네트워크를 포함하는 ldm.models.diffusion.ddpm.py 파일의 1416행에 정의된 미리 빌드된 클래스입니다.self.model의 기능은 wrapper와 다소 유사합니다. 모델에 따라 선택 특징 융합 방법은 위에서 생성된 텍스트와 노이즈를 융합하는 데 사용됩니다.

c_concat은 스태킹을 사용한 융합을 나타내고 c_crossattn은 어텐션을 사용한 융합을 나타냅니다.

class DiffusionWrapper(pl.LightningModule):
    def __init__(self, diff_model_config, conditioning_key):
        super().__init__()
        self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
        # stable diffusion的unet网络
        self.diffusion_model = instantiate_from_config(diff_model_config)
        self.conditioning_key = conditioning_key
        assert self.conditioning_key in [None, 'concat', 'crossattn', 'hybrid', 'adm', 'hybrid-adm', 'crossattn-adm']

    def forward(self, x, t, c_concat: list = None, c_crossattn: list = None, c_adm=None):
        if self.conditioning_key is None:
            out = self.diffusion_model(x, t)
        elif self.conditioning_key == 'concat':
            xc = torch.cat([x] + c_concat, dim=1)
            out = self.diffusion_model(xc, t)
        elif self.conditioning_key == 'crossattn':
            if not self.sequential_cross_attn:
                cc = torch.cat(c_crossattn, 1)
            else:
                cc = c_crossattn
            out = self.diffusion_model(x, t, context=cc)
        elif self.conditioning_key == 'hybrid':
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc)
        elif self.conditioning_key == 'hybrid-adm':
            assert c_adm is not None
            xc = torch.cat([x] + c_concat, dim=1)
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(xc, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'crossattn-adm':
            assert c_adm is not None
            cc = torch.cat(c_crossattn, 1)
            out = self.diffusion_model(x, t, context=cc, y=c_adm)
        elif self.conditioning_key == 'adm':
            cc = c_crossattn[0]
            out = self.diffusion_model(x, t, y=cc)
        else:
            raise NotImplementedError()

        return out

여기에 이미지 설명 삽입
코드의 self.diffusion_model은 Stable Diffusion의 Unet 네트워크이며 네트워크 구조는 ldm.modules.diffusionmodules.openaimodel.py 파일의 UNetModel 클래스에 있습니다.

II.UNetModel 모델 분석

UNetModel의 주요 작업은 시간 단계 t와 텍스트 임베딩을 결합하여 이 순간의 노이즈를 계산하는 것입니다 . UNet의 아이디어는 매우 간단하지만 StableDiffusion에서 UNetModel은 ResBlock과 Transformer 모듈로 구성되어 있어 전체적으로 일반 UNet보다 복잡합니다.

Prompt는 Frozen CLIP Text Encoder를 통해 Text Embedding을 획득하고 Timesteps는 전체 연결(MLP)을 통해 Timesteps Embedding을 획득합니다.

ResBlock은 시간 단계 Timesteps Embedding을 결합하는 데 사용되고 Transformer 모듈은 Text Embedding을 결합하는 데 사용됩니다.

여기에 큰 그림을 올려 놓고 학생들은 내부 모양의 변화를 볼 수 있습니다.
사진 설명을 추가해주세요

Unet 코드는 다음과 같습니다.

class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_heads_channels: if specified, ignore num_heads and instead use
                               a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
                               of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
                                    increased efficiency.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # 用于计算当前采样时间t的embedding
        time_embed_dim  = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)
        
        # 定义输入模块的第一个卷积
        # TimestepEmbedSequential也可以看作一个包装器,根据层的种类进行时间或者文本的融合。
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size  = model_channels
        input_block_chans   = [model_channels]
        ch                  = model_channels
        ds                  = 1
        # 对channel_mult进行循环,channel_mult一共有四个值,代表unet四个部分通道的扩张比例
        # [1, 2, 4, 4]
        for level, mult in enumerate(channel_mult):
            # 每个部分循环两次
            # 添加一个ResBlock和一个AttentionBlock
            for _ in range(num_res_blocks):
                # 先添加一个ResBlock
                # 用于对输入的噪声进行通道数的调整,并且融合t的特征
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                # ch便是上述ResBlock的输出通道数
                ch = mult * model_channels
                if ds in attention_resolutions:
                    # num_heads=8
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    # 使用了SpatialTransformer自注意力,加强全局特征,融合文本的特征
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            # 如果不是四个部分中的最后一个部分,那么都要进行下采样。
            if level != len(channel_mult) - 1:
                out_ch = ch
                # 在此处进行下采样
                # 一般直接使用Downsample模块
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            down=True,
                        )
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                # 为下一阶段定义参数。
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        if num_head_channels == -1:
            dim_head = ch // num_heads
        else:
            num_heads = ch // num_head_channels
            dim_head = num_head_channels
        if legacy:
            #num_heads = 1
            dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
        # 定义中间层
        # ResBlock + SpatialTransformer + ResBlock
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )
        self._feature_size += ch

        # 定义Unet上采样过程
        self.output_blocks = nn.ModuleList([])
        # 循环把channel_mult反了过来
        for level, mult in list(enumerate(channel_mult))[::-1]:
            # 上采样时每个部分循环三次
            for i in range(num_res_blocks + 1):
                ich = input_block_chans.pop()
                # 首先添加ResBlock层
                layers = [
                    ResBlock(
                        ch + ich,
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                # 然后进行SpatialTransformer自注意力
                if ds in attention_resolutions:
                    if num_head_channels == -1:
                        dim_head = ch // num_heads
                    else:
                        num_heads = ch // num_head_channels
                        dim_head = num_head_channels
                    if legacy:
                        #num_heads = 1
                        dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=dim_head,
                            use_new_attention_order=use_new_attention_order,
                        ) if not use_spatial_transformer else SpatialTransformer(
                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
                        )
                    )
                # 如果不是channel_mult循环的第一个
                # 且
                # 是num_res_blocks循环的最后一次,则进行上采样
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        ResBlock(
                            ch,
                            time_embed_dim,
                            dropout,
                            out_channels=out_ch,
                            dims=dims,
                            use_checkpoint=use_checkpoint,
                            use_scale_shift_norm=use_scale_shift_norm,
                            up=True,
                        )
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # 最后在输出部分进行一次卷积
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
            normalization(ch),
            conv_nd(dims, model_channels, n_embed, 1),
            #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"
        hs      = []
        # 用于计算当前采样时间t的embedding
        t_emb   = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb     = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        # 对输入模块进行循环,进行下采样并且融合时间特征与文本特征。
        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            hs.append(h)

        # 中间模块的特征提取
        h = self.middle_block(h, emb, context)

        # 上采样模块的特征提取
        for module in self.output_blocks:
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        # 输出模块
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        else:
            return self.out(h)

4. 그림 생성을 위한 숨겨진 공간 디코딩

여기에 이미지 설명 삽입
위의 단계를 통해 다중 샘플링을 통해 결과를 얻을 수 있으며 잠재 공간 디코딩을 통해 그림을 생성할 수 있습니다.

잠재 공간 복호화 과정은 매우 간단하며 decode_first_stage 메서드를 사용하여 위의 다중 샘플링 결과에서 그림을 생성합니다.

decode_first_stage 메서드에서 네트워크는 VAE를 호출하여 획득한 숨겨진 벡터 64x64x3을 디코딩하여 512x512x3의 그림을 얻습니다.

@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        z = rearrange(z, 'b h w c -> b c h w').contiguous()

    z = 1. / self.scale_factor * z
	# 一般无需分割输入,所以直接将x_noisy传入self.model中,在下面else进行
    if hasattr(self, "split_input_params"):
    	......
    else:
        if isinstance(self.first_stage_model, VQModelInterface):
            return self.first_stage_model.decode(z, force_not_quantize=predict_cids or force_not_quantize)
        else:
            return self.first_stage_model.decode(z)

이미지 대 이미지 예측 프로세스 코드

전체 예측 코드는 다음과 같습니다.

import os
import random

import cv2
import einops
import numpy as np
import torch
from PIL import Image
from pytorch_lightning import seed_everything

from ldm_hacked import *

# ----------------------- #
#   使用的参数
# ----------------------- #
# config的地址
config_path = "model_data/sd_v15.yaml"
# 模型的地址
model_path  = "model_data/v1-5-pruned-emaonly.safetensors"
# fp16,可以加速与节省显存
sd_fp16     = True
vae_fp16    = True

# ----------------------- #
#   生成图片的参数
# ----------------------- #
# 生成的图像大小为input_shape,对于img2img会进行Centter Crop
input_shape = [512, 512]
# 一次生成几张图像
num_samples = 1
# 采样的步数
ddim_steps  = 20
# 采样的种子,为-1的话则随机。
seed        = 12345
# eta
eta         = 0
# denoise强度,for img2img
denoise_strength = 1.0

# ----------------------- #
#   提示词相关参数
# ----------------------- #
# 提示词
prompt      = "a cute cat, with yellow leaf, trees"
# 正面提示词
a_prompt    = "best quality, extremely detailed"
# 负面提示词
n_prompt    = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality"
# 正负扩大倍数
scale       = 9
# img2img使用,如果不想img2img这设置为None。
image_path  = None

# ----------------------- #
#   保存路径
# ----------------------- #
save_path   = "imgs/outputs_imgs"

# ----------------------- #
#   创建模型
# ----------------------- #
model   = create_model(config_path).cpu()
model.load_state_dict(load_state_dict(model_path, location='cuda'), strict=False)
model   = model.cuda()
ddim_sampler = DDIMSampler(model)
if sd_fp16:
    model = model.half()

if image_path is not None:
    img = Image.open(image_path)
    img = crop_and_resize(img, input_shape[0], input_shape[1])

with torch.no_grad():
    if seed == -1:
        seed = random.randint(0, 65535)
    seed_everything(seed)
    
    # ----------------------- #
    #   对输入图片进行编码并加噪
    # ----------------------- #
    if image_path is not None:
        img = HWC3(np.array(img, np.uint8))
        img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
        img = torch.stack([img for _ in range(num_samples)], dim=0)
        img = einops.rearrange(img, 'b h w c -> b c h w').clone()
        if vae_fp16:
            img = img.half()
            model.first_stage_model = model.first_stage_model.half()
        else:
            model.first_stage_model = model.first_stage_model.float()

        ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
        t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
        z = model.get_first_stage_encoding(model.encode_first_stage(img))
        z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
        z_enc = z_enc.half() if sd_fp16 else z_enc.float()

    # ----------------------- #
    #   获得编码后的prompt
    # ----------------------- #
    cond    = {
    
    "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
    un_cond = {
    
    "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
    H, W    = input_shape
    shape   = (4, H // 8, W // 8)

    if image_path is not None:
        samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
    else:
        # ----------------------- #
        #   进行采样
        # ----------------------- #
        samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
                                                        shape, cond, verbose=False, eta=eta,
                                                        unconditional_guidance_scale=scale,
                                                        unconditional_conditioning=un_cond)

    # ----------------------- #
    #   进行解码
    # ----------------------- #
    x_samples = model.decode_first_stage(samples.half() if vae_fp16 else samples.float())

    x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)

# ----------------------- #
#   保存图片
# ----------------------- #
if not os.path.exists(save_path):
    os.makedirs(save_path)
for index, image in enumerate(x_samples):
    cv2.imwrite(os.path.join(save_path, str(index) + ".jpg"), cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

Supongo que te gusta

Origin blog.csdn.net/weixin_44791964/article/details/131992399
Recomendado
Clasificación