DDPM — a PyTorch Implementation


Paper: Denoising Diffusion Probabilistic Models

Reference: The Annotated Diffusion Model

Goal of DDPM:

Sample a noise image from a standard normal distribution and, after $T$ denoising steps, recover a generated image that resembles the training data, thereby completing the image-generation task.

Method of DDPM:

[Figure: the forward diffusion chain $q(x_t|x_{t-1})$ and the learned reverse chain $p_\theta(x_{t-1}|x_t)$, as illustrated in the DDPM paper.]

① Diffusion process (adding noise):

Noise is repeatedly added to a training image; after $T$ steps the image approximately becomes isotropic standard normal noise.
Each noising step is written $q(x_t|x_{t-1})$, where $t$ is the current time step (noise added $t$ times), $t-1$ is the previous step, $x_t$ is the image at the current step, and $x_{t-1}$ is the image at the previous step.
The whole process is a Markov chain: the image at the current step depends only on the previous step and on no other step.
Fix a sequence $\beta$ of length $T$, with $\beta_t$ increasing monotonically within $(0,1)$. The variance of the noise added at step $t$ is $\beta_t$, and the mean is determined jointly by $\beta_t$ and $x_{t-1}$, which gives the single-step transition $q(x_t|x_{t-1})$ and the whole diffusion process $q(x_{1:T}|x_0)$:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t \mathbf{I}\right), \qquad q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$
It turns out that the noisy image $x_t$ at any step can be obtained directly from the initial image (the original) $x_0$ and the $\beta$ sequence. Defining $\alpha_t = 1-\beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t}\alpha_s$, we get:
$$q(x_t \mid x_0) = \mathcal{N}\!\left(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\mathbf{I}\right), \qquad x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\quad \epsilon \sim \mathcal{N}(0, \mathbf{I})$$
As $T$ approaches $\infty$, $x_T$ can be regarded as isotropic standard normal noise.
The diffusion process involves no network: once the initial image $x_0$ and the $\beta$ sequence are fixed, every step of the forward process can be computed in closed form.
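
As a quick sanity check (this derivation is not spelled out in the original post), the closed form follows by unrolling the single-step update with the reparameterization trick and merging the independent Gaussians:

$$x_t = \sqrt{\alpha_t}\,x_{t-1} + \sqrt{1-\alpha_t}\,\epsilon_{t-1} = \sqrt{\alpha_t\alpha_{t-1}}\,x_{t-2} + \sqrt{1-\alpha_t\alpha_{t-1}}\,\bar{\epsilon}_{t-2} = \cdots = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$$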

② Reverse diffusion process (denoising):

Noise is repeatedly removed from the noise image; after $T$ steps it should be restored to the initial image $x_0$.
Each denoising step is written $p(x_{t-1}|x_t)$:
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\right), \qquad p_\theta(x_{0:T}) = p(x_T)\prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)$$
$p(x_{t-1}|x_t)$ cannot be computed directly, so a network is used to approximate it.
There are several ways to fit this distribution with a network. The authors build a network $D_\theta(x_t, t)$ ($\theta$ denotes the network parameters) that takes $x_t$ and $t$ as input and outputs the noise $z_t$ at step $t$ (rather than predicting $x_{t-1}$ directly); this is the noise-prediction network written $\epsilon_\theta$ in the paper. The variance of each reverse step is assumed fixed, so only the mean $\mu_\theta$ enters the loss.
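
Concretely, in the paper's $\epsilon$-notation the predicted noise determines the mean of each reverse step (this is the form implemented by `p_sample` in the code below):

$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\!\left(x_{t-1};\ \mu_\theta(x_t, t),\ \sigma_t^2\mathbf{I}\right), \qquad \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)$$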
Loss function:
The goal of the generative model is to make the generated distribution $p_\theta(x_0)$ as close as possible to the data distribution $q(x_0)$, measured by the KL divergence $KL(q\,\|\,p_\theta)$. This yields an upper bound on the negative log-likelihood; minimizing the bound maximizes the likelihood:
$$\mathbb{E}\left[-\log p_\theta(x_0)\right] \le \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] = \mathbb{E}_q\left[-\log p(x_T) - \sum_{t \ge 1}\log\frac{p_\theta(x_{t-1} \mid x_t)}{q(x_t \mid x_{t-1})}\right] =: L$$
Expressed in terms of KL divergences, this can be written as:
$$L = \mathbb{E}_q\Big[\underbrace{D_{KL}\big(q(x_T \mid x_0)\,\|\,p(x_T)\big)}_{L_T} + \sum_{t>1}\underbrace{D_{KL}\big(q(x_{t-1} \mid x_t, x_0)\,\|\,p_\theta(x_{t-1} \mid x_t)\big)}_{L_{t-1}} \underbrace{-\log p_\theta(x_0 \mid x_1)}_{L_0}\Big]$$
The first term $L_T$ contains no model parameters $\theta$; it is a constant and plays no role in optimization. The second term $L_{t-1}$ and the third term $L_0$ can be expanded and simplified ($L_0$ can be viewed as $L_{t-1}$ with $t=1$).
Note that $q(x_{t-1}|x_t, x_0)$ is tractable; using Bayes' rule and the reparameterization trick it simplifies to:
$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t\mathbf{I}\right), \quad \tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t, \quad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$
The loss $L_{t-1}$ (allowing $t=1$ so that the $L_0$ term is included):
$$L_{t-1} = \mathbb{E}_q\left[\frac{1}{2\sigma_t^2}\left\|\tilde{\mu}_t(x_t, x_0) - \mu_\theta(x_t, t)\right\|^2\right] + C$$

Writing $x_t(x_0, \epsilon) = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon$ with $\epsilon \sim \mathcal{N}(0, \mathbf{I})$, the posterior mean becomes

$$\tilde{\mu}_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon\right)$$

so parameterizing the model mean the same way, $\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right)$, reduces the loss to predicting the noise:

$$L_{t-1} = \mathbb{E}_{x_0, \epsilon}\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t(1-\bar{\alpha}_t)}\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\|^2\right]$$
The authors drop the weighting coefficient, which gives the simplified loss:
$$L_{\text{simple}}(\theta) = \mathbb{E}_{t, x_0, \epsilon}\left[\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\|^2\right]$$

Discussion of $L_0$:

One can see that the network only has a gradient to update when the value of $x$ lies in the interval $(-1, 1)$, so the authors rescale the input image data from $[0, 255]$ to $[-1, 1]$ to ensure that backpropagation proceeds normally.
$$p_\theta(x_0 \mid x_1) = \prod_{i=1}^{D}\int_{\delta_-(x_0^i)}^{\delta_+(x_0^i)} \mathcal{N}\!\left(x;\ \mu_\theta^i(x_1, 1),\ \sigma_1^2\right)dx, \qquad \delta_+(x)=\begin{cases}\infty & x=1\\ x+\frac{1}{255} & x<1\end{cases}, \quad \delta_-(x)=\begin{cases}-\infty & x=-1\\ x-\frac{1}{255} & x>-1\end{cases}$$

Training and sampling in DDPM:

The two procedures from the paper (Algorithm 1 and Algorithm 2):

Training: repeat — sample $x_0 \sim q(x_0)$, $t \sim \mathrm{Uniform}(\{1,\dots,T\})$ and $\epsilon \sim \mathcal{N}(0,\mathbf{I})$, then take a gradient step on $\nabla_\theta\left\|\epsilon - \epsilon_\theta\!\left(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon,\ t\right)\right\|^2$ — until converged.

Sampling: start from $x_T \sim \mathcal{N}(0,\mathbf{I})$; for $t = T,\dots,1$, draw $z \sim \mathcal{N}(0,\mathbf{I})$ if $t>1$ else $z=0$, and set $x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z$; finally return $x_0$.

Structure of DDPM:

It is built on a U-Net, augmented with sinusoidal time-step (position) encoding, residual blocks, attention, and group normalization.

train.py

import os

import torch
from torch.utils.data import DataLoader

import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import numpy as np
from torchvision import transforms, datasets

from model import Unet  # the DDPM denoising network (U-Net)


# Four ways to generate the β schedule; each takes the total number of steps T and returns the β sequence
def cosine_beta_schedule(timesteps, s=0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)


def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)


def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2


def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
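
# Quick sanity check (optional, not in the original post): every schedule returns a
# 1-D tensor of length T whose values stay strictly inside (0, 1), e.g.
#   betas = cosine_beta_schedule(300)
#   assert betas.shape == (300,) and betas.min() > 0 and betas.max() < 1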


# Take the value a[t] at time step t from sequence a (batch_size entries), reshaped so it broadcasts against x_shape (first dimension is batch_size)
def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
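
# Usage sketch (hypothetical shapes): with a = betas of shape (T,), t of shape (B,)
# and x_shape = (B, C, H, W), extract(a, t, x_shape) returns a (B, 1, 1, 1) tensor
# that broadcasts over an image batch.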


# Forward diffusion sampling: compute xt from x0 and t
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise


# Loss function: three options; the original paper uses l2
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start, t, noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == "l1":
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == "l2":
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss


# Reverse diffusion sampling: compute x(t-1) from xt and t; this step goes through the network
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        return model_mean + torch.sqrt(posterior_variance_t) * noise


# Reverse diffusion over T steps: starting from xT, compute each xi and collect the list of images [xi]; this goes through the network
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device
    b = shape[0]
    img = torch.randn(shape, device=device)
    imgs = []
    for i in tqdm(reversed(range(0, timesteps)), desc="sampling loop time step", total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu())
    return imgs


# Reverse diffusion over T steps; batch_size sets how many images to generate, used for visualizing results
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=1):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))


if __name__ == "__main__":
    timesteps = 300  # total number of steps T
    # The following are all sequences (tensors); index with t to get the value at step t: xt = X[t]
    betas = linear_beta_schedule(timesteps=timesteps)  # pick one schedule to generate β(t)
    alphas = 1. - betas  # α(t)
    alphas_cumprod = torch.cumprod(alphas, axis=0)  # cumulative product of α, i.e. α_bar(t)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0),
                                value=1.0)  # drop the last value of α_bar and prepend 1, giving the previous step's α_bar, i.e. α_bar(t-1)
    sqrt_recip_alphas = torch.sqrt(1. / alphas)  # 1/sqrt(α(t))
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)  # sqrt(α_bar(t))
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)  # sqrt(1 - α_bar(t))
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod)  # β(t)·(1 - α_bar(t-1))/(1 - α_bar(t)), i.e. β~(t)

    total_epochs = 10
    image_size = 28
    channels = 1
    batch_size = 256
    lr = 1e-3

    os.makedirs("../dataset/mnist", exist_ok=True)
    os.makedirs("images", exist_ok=True)

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1)  # rescale the input from (0, 1) to (-1, 1)
    ])
    dataset = datasets.MNIST(root="../dataset/mnist", train=True, transform=transform, download=True)

    reverse_transform = transforms.Compose([  # convert a tensor back to a PIL image
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),
        transforms.Lambda(lambda t: t * 255),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage()
    ])

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Unet(dim=image_size, channels=channels, dim_mults=(1, 2, 4))
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(total_epochs):
        total_loss = 0
        pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{total_epochs}", postfix=dict,
                    miniters=0.3)
        for iter, (img, _) in enumerate(dataloader):
            img = img.to(device)
            optimizer.zero_grad()
            batch_size = img.shape[0]

            t = torch.randint(0, timesteps, (batch_size,), device=device).long()

            loss = p_losses(model, img, t, loss_type="huber")  # pick a loss type and compute the loss

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

            pbar.set_postfix(**{"Loss": loss.item()})
            pbar.update(1)
        pbar.close()
        print("total_loss:%.4f" %
              (total_loss / len(dataloader)))

        # Show the generation (denoising) process for one image: one frame every 3 steps, 100 frames in total (displayed in a single figure)
        val_images = sample(model, image_size, batch_size=1, channels=channels)
        fig, axs = plt.subplots(10, 10, figsize=(20, 20))
        plt.rc("text", color="blue")
        for t in range(100):
            i = t // 10
            j = t % 10
            image = val_images[t * 3 + 2].squeeze(0)
            image = reverse_transform(image)
            axs[i, j].imshow(image, cmap="gray")
            axs[i, j].set_axis_off()
            axs[i, j].set_title(r"$q(\mathbf{x}_{" + str(300 - 3 * t - 3) + "})$")
        plt.savefig(f"images/{
      
      epoch + 1}.png", bbox_inches='tight')
        plt.close()

model.py

import math
from inspect import isfunction
from functools import partial

from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


class Residual(nn.Module):
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x


def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode="nearest"),
        nn.Conv2d(dim, default(dim_out, dim), 3, 1, 1)
    )


def Downsample(dim, dim_out=None):
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1, 1, 0)
    )


class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super(SinusoidalPositionEmbeddings, self).__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=1)
        return embeddings
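
# SinusoidalPositionEmbeddings shape sketch (not in the original post): given a (B,)
# tensor of timesteps and an even dim, it returns a (B, dim) embedding whose first
# half of the channels are sines and second half cosines at geometrically spaced
# frequencies.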


class WeightStandardizedConv2d(nn.Conv2d):
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups
        )


class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)

        return x


class ResnetBlock(nn.Module):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super(ResnetBlock, self).__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None
        )
        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1, 1, 0) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb=None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(Attention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, 1, 0, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1, 1, 0)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)
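
# Note (not from the original post): the full Attention above costs O(N^2) in the
# number of pixels N and is only used in the U-Net bottleneck (mid_attn), whereas
# LinearAttention below forms the k·v context first, so its cost grows linearly
# with N and it is applied at every resolution in the down/up paths.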


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super(LinearAttention, self).__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, 1, 0, bias=False)
        self.to_out = nn.Sequential(
            nn.Conv2d(hidden_dim, dim, 1, 1, 0),
            nn.GroupNorm(1, dim)
        )

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )

        q = q.softmax(dim=-2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)


class Unet(nn.Module):
    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, self_condition=False,
                 resnet_block_groups=4):
        super(Unet, self).__init__()

        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, 1, 0)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        time_dim = dim * 4

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )

        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.ModuleList([
                    block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                    block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                    Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, 1, 1)
                ])
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            self.ups.append(
                nn.ModuleList([
                    block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                    block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, 1, 1)
                ])
            )

        self.out_dim = default(out_dim, channels)

        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1, 1, 0)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)
            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        x = torch.cat((x, r), dim=1)
        x = self.final_res_block(x, t)
        return self.final_conv(x)


if __name__ == "__main__":
    model = Unet(28)
    print(model)
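    # A minimal smoke test (an addition, not from the original post): with the default
    # dim_mults=(1, 2, 4, 8) the input is downsampled three times, so the spatial size
    # must be divisible by 8; 32x32 works here, 28x28 would not with these defaults.
    x = torch.randn(2, 3, 32, 32)           # (batch, channels, height, width)
    t = torch.randint(0, 300, (2,)).long()   # one timestep per sample
    print(model(x, t).shape)                 # expected: torch.Size([2, 3, 32, 32])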

Generated images (the denoising process):
[Figure: a 10×10 grid of intermediate samples showing one image being denoised from pure noise to an MNIST digit.]


Reposted from blog.csdn.net/Peach_____/article/details/128663957