[AI Theory Learning] Implementing the diffusion model DDPM with PyTorch


The previous blog post derived the diffusion model DDPM by hand. This article implements the diffusion model with PyTorch in a Google Colab notebook.

DDPM code implementation

Note that there are several perspectives on the diffusion model. Here we take the discrete-time (latent variable model) perspective, but be sure to check out the other perspectives.

Neural Networks

A neural network needs to take in a noisy image at a particular time step and return the predicted noise. Note that the predicted noise is a tensor with the same size/resolution as the input image. So technically, the network takes in and outputs tensors of the same shape. What type of neural network can we use for this?

The approach generally used here is very similar to an autoencoder, which you may remember from typical "intro to deep learning" tutorials. Autoencoders have a so-called "bottleneck" layer between the encoder and decoder. The encoder first encodes the image into a smaller hidden representation, called the "bottleneck", and the decoder decodes that hidden representation back into the actual image. This forces the network to keep only the most important information in the bottleneck layer.

In terms of architecture, the DDPM authors adopted the U-Net introduced by (Ronneberger et al., 2015), which achieved state-of-the-art results in medical image segmentation at the time. This network, like any autoencoder, has a bottleneck in the middle, ensuring that the network learns only the most important information. Importantly, it introduces residual (skip) connections between the encoder and decoder, which greatly improve gradient flow (inspired by ResNet, He et al., 2015).
(Figure: the U-Net architecture)
As shown, the U-Net model first downsamples the input (i.e. makes the input smaller in terms of spatial resolution) and then upsamples it.

Next, we implement this network step by step.

!pip install -q -U einops datasets matplotlib tqdm

Import related dependent libraries

import math
from inspect import isfunction
from functools import partial

%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from einops import rearrange, reduce
from einops.layers.torch import Rearrange

import torch
from torch import nn, einsum
import torch.nn.functional as F

define helper functions

First, define some helper functions and classes that will be used when implementing the neural network. Importantly, a Residual module is defined, which simply adds the input to the output of a given function (in other words, it adds a residual connection around a specific function).

def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr


class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x
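
As a quick illustration (a minimal sketch with made-up numbers), num_to_groups splits a total count into chunks of at most divisor, and Residual simply adds the input back to the output of the wrapped function:

print(num_to_groups(10, 4))              # [4, 4, 2]

identity_residual = Residual(nn.Identity())
print(identity_residual(torch.ones(2)))  # tensor([2., 2.]), i.e. fn(x) + x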

We also define aliases for upsampling and downsampling operations.

def Upsample(dim, dim_out=None):
    return nn.Sequential(
        nn.Upsample(scale_factor=2, mode='nearest'),
        nn.Conv2d(dim, default(dim_out, dim), 3, padding=1),
    )


def Downsample(dim, dim_out=None):
    # no more strided convolutions or pooling
    return nn.Sequential(
        Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
        nn.Conv2d(dim * 4, default(dim_out, dim), 1),
    )
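
A quick shape check (with hypothetical sizes) makes the effect of these aliases concrete: Downsample halves the spatial resolution while keeping the channel count, and Upsample doubles it:

x = torch.randn(1, 16, 28, 28)
print(Downsample(16)(x).shape)   # torch.Size([1, 16, 14, 14])
print(Upsample(16)(x).shape)     # torch.Size([1, 16, 56, 56])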

position embedding

Since the parameters of the neural network are shared across time (noise levels), the authors adopt sinusoidal position embeddings, inspired by the Transformer (Vaswani et al., 2017), to encode $t$. This allows the neural network to "know" which particular time step (noise level) it is operating on, for every image in the batch.

The SinusoidalPositionEmbeddings module takes a tensor of shape (batch_size, 1) as input (i.e. the noise levels of several noisy images in a batch) and turns it into a tensor of shape (batch_size, dim), where dim is the dimensionality of the position embeddings. This embedding is then added in each residual block, as we will see later.

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        return embeddings

In short, the time step $t$ is encoded as an embedding and fed into the network together with the input, so that the network "knows" which diffusion step the current input corresponds to.
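
A minimal sanity check (with assumed batch size and embedding dimension) shows the expected shape transformation:

pos_emb = SinusoidalPositionEmbeddings(dim=32)
t = torch.randint(0, 300, (8,))   # 8 random time steps
print(pos_emb(t).shape)           # torch.Size([8, 32])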

ResNet block

Next, define the core building block of the U-Net model. The DDPM authors used a Wide ResNet block (Zagoruyko et al., 2016), but Phil Wang replaced the standard convolutional layer with a "weight standardized" version, which works better in combination with group normalization (see Kolesnikov et al., 2019).

class WeightStandardizedConv2d(nn.Conv2d):
    """
    https://arxiv.org/abs/1903.10520
    weight standardization purportedly works synergistically with group normalization
    """
    def forward(self, x):
        eps = 1e-5 if x.dtype == torch.float32 else 1e-3

        weight = self.weight
        mean = reduce(weight, "o ... -> o 1 1 1", "mean")
        var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
        normalized_weight = (weight - mean) * (var + eps).rsqrt()

        return F.conv2d(
            x,
            normalized_weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

class Block(nn.Module):
    def __init__(self, dim, dim_out, groups=8):
        super().__init__()
        self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def forward(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ResnetBlock(nn.Module):
    """https://arxiv.org/abs/1512.03385"""
    def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
        super().__init__()
        self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2))
            if exists(time_emb_dim)
            else None
        )

        self.block1 = Block(dim, dim_out, groups=groups)
        self.block2 = Block(dim_out, dim_out, groups=groups)
        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def forward(self, x, time_emb = None):
        scale_shift = None
        if exists(self.mlp) and exists(time_emb):
            time_emb = self.mlp(time_emb)
            time_emb = rearrange(time_emb, "b c -> b c 1 1")
            scale_shift = time_emb.chunk(2, dim=1)

        h = self.block1(x, scale_shift=scale_shift)
        h = self.block2(h)
        return h + self.res_conv(x)
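
The following quick check (hypothetical channel sizes) shows how a ResnetBlock consumes a feature map plus a time embedding and only changes the channel dimension:

block = ResnetBlock(dim=32, dim_out=64, time_emb_dim=128)
x = torch.randn(2, 32, 28, 28)
t_emb = torch.randn(2, 128)       # time embedding produced elsewhere
print(block(x, t_emb).shape)      # torch.Size([2, 64, 28, 28])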

attention module

Now, define the attention module, which the DDPM authors add in between the convolutional blocks. Attention is the building block of the famous Transformer architecture (Vaswani et al., 2017), which has achieved great success in various domains of AI, from natural language processing and vision to protein folding. Phil Wang employs two variants of attention: one is regular multi-head self-attention (as used in the Transformer), the other is a linear attention variant (Shen et al., 2018), whose time and memory requirements scale linearly with sequence length, rather than quadratically as in regular attention.

For a detailed explanation of the attention mechanism, see Jay Alammar's excellent blog post.

class Attention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q * self.scale

        sim = einsum("b h d i, b h d j -> b h i j", q, k)
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        out = einsum("b h i j, b h d j -> b h i d", attn, v)
        out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
        return self.to_out(out)


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=1)
        q, k, v = map(
            lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
        )
        q = q.softmax(dim=2)
        k = k.softmax(dim=-1)

        q = q * self.scale
        context = torch.einsum("b h d n, b h e n -> b h d e", k, v)

        out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
        out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
        return self.to_out(out)
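
Both attention variants are drop-in modules that preserve the (batch, channels, height, width) shape, as this small check (hypothetical sizes) shows:

x = torch.randn(1, 32, 14, 14)
print(Attention(32)(x).shape)        # torch.Size([1, 32, 14, 14])
print(LinearAttention(32)(x).shape)  # torch.Size([1, 32, 14, 14])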

group normalization

The authors of DDPM interleave group normalization (Wu et al., 2018) with the convolutional/attention layers of the U-Net. Below, a PreNorm class is defined, which applies group normalization before the attention layer, as we will see later. It is worth noting that there has been debate about whether to apply normalization before or after attention in Transformers.

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def forward(self, x):
        x = self.norm(x)
        return self.fn(x)

Conditional U-Net

Now that we have defined all the building blocks (position embeddings, ResNet blocks, attention and group normalization), it is time to define the entire neural network. Recall that the job of the network $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$ is to take a batch of noisy images and their respective noise levels, and output the noise added to the input. More formally:

  • The network takes a batch of noisy images of shape (batch_size, num_channels, height, width) and a batch of noise levels of shape (batch_size, 1) as input, and returns a tensor of shape (batch_size, num_channels, height, width)

The network is constructed as follows:

  • First, a convolutional layer is applied to the batch of noisy images, and position embeddings are computed for the noise levels
  • Then, a series of downsampling stages is applied. Each downsampling stage consists of 2 ResNet blocks + group norm + attention + a residual connection + a downsample operation
  • In the middle of the network, ResNet blocks are applied again, interleaved with attention
  • Next, a series of upsampling stages is applied. Each upsampling stage consists of 2 ResNet blocks + group norm + attention + a residual connection + an upsample operation
  • Finally, a ResNet block followed by a convolutional layer is applied.

Ultimately, the network is built up by stacking these blocks like Lego bricks (but it is important to understand how they work).

class Unet(nn.Module):
    def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, self_condition=False,
                 resnet_block_groups=4):
        super().__init__()
        # determine dimensions
        self.channels = channels
        self.self_condition = self_condition
        input_channels = channels * (2 if self_condition else 1)

        init_dim = default(init_dim, dim)
        self.init_conv = nn.Conv2d(input_channels, init_dim, 1, padding=0)  #  changed to 1 and 0 from 7,3

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ResnetBlock, groups=resnet_block_groups)

        # time embeddings
        time_dim = dim * 4
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(dim),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim),
        )

        # layers
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                        Downsample(dim_in, dim_out)
                        if not is_last
                        else nn.Conv2d(dim_in, dim_out, 3, padding=1),
                    ]
                )
            )
        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
            is_last = ind == (len(in_out) - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                        Upsample(dim_out, dim_in)
                        if not is_last
                        else nn.Conv2d(dim_out, dim_in, 3, padding=1),
                    ]
                )
            )
        self.out_dim = default(out_dim, channels)
        self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
        self.final_conv = nn.Conv2d(dim, self.out_dim, 1)

    def forward(self, x, time, x_self_cond=None):
        if self.self_condition:
            x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
            x = torch.cat((x_self_cond, x), dim=1)

        x = self.init_conv(x)
        r = x.clone()

        t = self.time_mlp(time)

        h = []
        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            h.append(x)

            x = block2(x, t)
            x = attn(x)
            h.append(x)

            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        for block1, block2, attn, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = block1(x, t)

            x = torch.cat((x, h.pop()), dim=1)
            x = block2(x, t)
            x = attn(x)

            x = upsample(x)

        x = torch.cat((x, r), dim=1)
        x = self.final_res_block(x, t)
        return self.final_conv(x)
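
Before moving on, a quick forward pass on random data (assumed Fashion-MNIST-like sizes, purely for illustration) confirms that the U-Net maps a batch of noisy images and time steps back to a tensor of the same shape as the images:

net = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
x = torch.randn(2, 1, 28, 28)     # batch of "noisy images"
t = torch.randint(0, 300, (2,))   # one time step per image
print(net(x, t).shape)            # torch.Size([2, 1, 28, 28])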

Define the forward diffusion process

The forward diffusion process gradually adds noise to an image from the real data distribution over $T$ time steps, according to a variance schedule. The original DDPM authors adopted a linear schedule:

We set the forward process variances to constants increasing linearly from $\beta_1 = 10^{-4}$ to $\beta_T = 0.02$.

However, (Nichol et al., 2021) showed that better results can be obtained with a cosine schedule. Below, we define several different schedules for the $T$ time steps (we will choose one later):

def cosine_beta_schedule(timesteps, s=0.008):
    """
    cosine schedule as proposed in https://arxiv.org/abs/2102.09672
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start, beta_end, timesteps)

def quadratic_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

def sigmoid_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    betas = torch.linspace(-6, 6, timesteps)
    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
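
As a small illustration of how the schedules differ (a sketch, not part of the original post), we can compare the cumulative product $\bar{\alpha}_t = \prod_{s \le t}(1-\beta_s)$ at the final step, which indicates how much of the original signal survives after $T$ steps:

T = 300
for name, schedule in [("linear", linear_beta_schedule), ("cosine", cosine_beta_schedule)]:
    alpha_bar = torch.cumprod(1.0 - schedule(T), dim=0)
    print(name, "alpha_bar at t=T:", alpha_bar[-1].item())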

First, we use the linear schedule with $T = 300$ time steps and define the various variables we need from the $\beta_t$, such as the cumulative product of variances $\bar{\alpha}_t$. Each variable below is just a 1D tensor storing the values from $t = 1$ to $T$. Importantly, we also define an extract function, which allows us to extract the appropriate $t$ index for a batch of time steps.

timesteps = 300

# define beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define alphas
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)

# calculations for diffusion q(x_t | x_{t-1}) and others
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)

# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    batch_size = t.shape[0]
    out = a.gather(-1, t.cpu())
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
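
For example (a small illustration with arbitrary time steps), extract gathers the schedule value for each time step in a batch and reshapes it so it broadcasts over the image dimensions:

t = torch.tensor([0, 150, 299])
vals = extract(sqrt_alphas_cumprod, t, (3, 1, 28, 28))
print(vals.shape)   # torch.Size([3, 1, 1, 1])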

We will use a cat image to illustrate how noise is added at each time step of the diffusion process:

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw) # PIL image of shape HWC
image

cat
Noise is added to PyTorch tensors, not Pillow images. So we first define image transformations that convert a PIL image to a PyTorch tensor (on which noise can be added) and vice versa.

These transformations are fairly simple: we first divide by 255 (so the values are in the [0,1] range), and then scale them to the [-1,1] range. The DDPM paper mentions:

We assume that image data consists of integers in {0, 1, ..., 255}, scaled linearly to [−1, 1]. This ensures that the neural network reverse process operates on consistently scaled inputs starting from the standard normal prior $p(\mathbf{x}_T)$.

from torchvision.transforms import Compose, ToTensor, Lambda, ToPILImage, CenterCrop, Resize

image_size = 128
transform = Compose([
    Resize(image_size),
    CenterCrop(image_size),
    ToTensor(),  # turn into torch Tensor of shape CHW, divide by 255
    Lambda(lambda t: (t * 2) - 1),

])

x_start = transform(image).unsqueeze(0)
x_start.shape

Output:
torch.Size([1, 3, 128, 128])

Additionally, a reverse transform is defined, which takes a PyTorch tensor with values in [-1, 1] and turns it back into a PIL image:

import numpy as np

reverse_transform = Compose([
     Lambda(lambda t: (t + 1) / 2),
     Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
     Lambda(lambda t: t * 255.),
     Lambda(lambda t: t.numpy().astype(np.uint8)),
     ToPILImage(),
])
reverse_transform(x_start.squeeze())

cat
Now the forward diffusion process can be defined as in the paper, using the nice property $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1-\bar{\alpha}_t)\mathbf{I}\big)$, i.e. $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}$ with $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$:

# forward diffusion (using the nice property)
def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = torch.randn_like(x_start)

    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x_start.shape
    )

    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

Test at a specific time step:

def get_noisy_image(x_start, t):
  # add noise
  x_noisy = q_sample(x_start, t=t)

  # turn back into PIL image
  noisy_image = reverse_transform(x_noisy.squeeze())

  return noisy_image
# take time step
t = torch.tensor([40])

get_noisy_image(x_start, t)

noise cat
Visualize the results at different time steps:

import matplotlib.pyplot as plt

# use seed for reproducibility
torch.manual_seed(0)

# source: https://pytorch.org/vision/stable/auto_examples/plot_transforms.html#sphx-glr-auto-examples-plot-transforms-py
def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    fig, axs = plt.subplots(figsize=(200,200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
plot([get_noisy_image(x_start, torch.tensor([t])) for t in [0, 50, 100, 150, 199]])

Visualization of results at different time steps
Define a loss function given a model:

def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
    if noise is None:
        noise = torch.randn_like(x_start)

    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = denoise_model(x_noisy, t)

    if loss_type == 'l1':
        loss = F.l1_loss(noise, predicted_noise)
    elif loss_type == 'l2':
        loss = F.mse_loss(noise, predicted_noise)
    elif loss_type == "huber":
        loss = F.smooth_l1_loss(noise, predicted_noise)
    else:
        raise NotImplementedError()

    return loss

Here, denoise_model is the U-Net defined above. We will use the Huber loss between the true and the predicted noise.
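
A quick smoke test of the loss (with an assumed small U-Net and random data, purely for illustration) ties the pieces together:

tiny_model = Unet(dim=28, channels=1, dim_mults=(1, 2))
x_start_batch = torch.randn(4, 1, 28, 28)
t = torch.randint(0, timesteps, (4,)).long()
print(p_losses(tiny_model, x_start_batch, t, loss_type="huber").item())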

Define PyTorch dataset + DataLoader

A regular PyTorch dataset is defined here. The dataset simply consists of images from a real dataset (such as Fashion-MNIST, CIFAR-10 or ImageNet), scaled linearly to $[-1, 1]$.
Each image is resized to the same size and randomly flipped horizontally. From the paper:

We used random horizontal flipping during the training of CIFAR10; we tried training with and without flipping and found that flipping slightly improved sample quality.

Here, the Datasets library is used to easily load the Fashion-MNIST dataset from the hub. This dataset consists of images that already all have the same resolution, namely 28x28.

from datasets import load_dataset

# load dataset from the hub
dataset = load_dataset('fashion_mnist')
image_size = 28
channels = 1
batch_size = 128

Next, we define a function that will be applied on the fly to the entire dataset, using the with_transform function. The function just applies some basic image preprocessing: random horizontal flips, rescaling, and finally mapping the values to the [-1, 1] range.

from torchvision import transforms
from torch.utils.data import DataLoader

# define image transformations(e.g. using torchvision)
transform = Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) -1)
])

# define function
def transforms(examples):
   examples["pixel_values"] = [transform(image.convert("L")) for image in examples["image"]]
   del examples["image"]

   return examples

transformed_dataset = dataset.with_transform(transforms).remove_columns("label")

# create dataloader
dataloader = DataLoader(transformed_dataset["train"], batch_size=batch_size, shuffle=True)
batch = next(iter(dataloader))
print(batch.keys())  # dict_keys(['pixel_values'])
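
As a sanity check, each batch should contain 28x28 grayscale images, so the pixel values tensor is expected to have shape (batch_size, 1, 28, 28):

print(batch["pixel_values"].shape)  # torch.Size([128, 1, 28, 28])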

sampling

Since we will sample from the model during training (in order to track progress), we define the corresponding code below. Sampling is summarized as follows:
(Figure: the sampling algorithm from the DDPM paper)
Generating new images from a diffusion model works by reversing the diffusion process: starting from $T$, we sample pure noise from a Gaussian distribution, and then use the neural network to gradually denoise it (using the conditional probabilities it has learned), until we end up at time step $t = 0$. As shown above, a slightly less noisy image $\mathbf{x}_{t-1}$ is obtained by plugging in the reparametrized mean $\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(\mathbf{x}_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\right)$, computed with our noise predictor. Note that the variance is known ahead of time.

Ideally, we end up with an image that looks like it came from the real data distribution. The code below implements this.

@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    
    # Equation 11 in the paper
    # Use our model (noise predictor) to predict the mean
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )

    if t_index == 0:
        return model_mean
    else:
        posterior_variance_t = extract(posterior_variance, t, x.shape)
        noise = torch.randn_like(x)
        # Algorithm 2 line 4:
        return model_mean + torch.sqrt(posterior_variance_t) * noise 

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    device = next(model.parameters()).device

    b = shape[0]
    # start from pure noise (for each example in the batch)
    img = torch.randn(shape, device=device)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
        imgs.append(img.cpu().numpy())
    return imgs

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

training model

Next, we train the model in the regular PyTorch fashion. We also define some logic to periodically save generated images, using the sample method defined above.

from pathlib import Path

def num_to_groups(num, divisor):
    groups = num // divisor
    remainder = num % divisor
    arr = [divisor] * groups
    if remainder > 0:
        arr.append(remainder)
    return arr

results_folder = Path("./results")
results_folder.mkdir(exist_ok = True)
save_and_sample_every = 1000

Below, the model is defined and moved to the GPU, and a standard optimizer (Adam) is also defined.

from torch.optim import Adam

device = "cuda" if torch.cuda.is_available() else "cpu"

model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)
model.to(device)

optimizer = Adam(model.parameters(), lr=1e-3)

Start training:

from torchvision.utils import save_image

epochs = 6

for epoch in range(epochs):
    for step, batch in enumerate(dataloader):
      optimizer.zero_grad()

      batch_size = batch["pixel_values"].shape[0]
      batch = batch["pixel_values"].to(device)

      # Algorithm 1 line 3: sample t uniformally for every example in the batch
      t = torch.randint(0, timesteps, (batch_size,), device=device).long()

      loss = p_losses(model, batch, t, loss_type="huber")

      if step % 100 == 0:
        print("Loss:", loss.item())

      loss.backward()
      optimizer.step()

      # save generated images
      if step != 0 and step % save_and_sample_every == 0:
        milestone = step // save_and_sample_every
        batches = num_to_groups(4, batch_size)
        all_images_list = list(map(lambda n: sample(model, image_size=image_size, batch_size=n, channels=channels), batches))
        # sample() returns a list of numpy arrays (one per time step); keep only the final denoised images
        all_images = torch.cat([torch.from_numpy(imgs[-1]) for imgs in all_images_list], dim=0)
        all_images = (all_images + 1) * 0.5
        save_image(all_images, str(results_folder / f'sample-{milestone}.png'), nrow=6)

Training process:

Loss: 0.5570111274719238
Loss: 0.06583500653505325
Loss: 0.06006840616464615
Loss: 0.051015421748161316
Loss: 0.0394190177321434
Loss: 0.04075610265135765
Loss: 0.039987701922655106
Loss: 0.03415030241012573
Loss: 0.030019590631127357
Loss: 0.036297883838415146
Loss: 0.037256866693496704
Loss: 0.03864285722374916
Loss: 0.03298967331647873
Loss: 0.03331328555941582
Loss: 0.027535393834114075
Loss: 0.03803558647632599
Loss: 0.03721949830651283
Loss: 0.03478413075208664
Loss: 0.03918925300240517
Loss: 0.03608154132962227
Loss: 0.027622627094388008
Loss: 0.02948344498872757
Loss: 0.029868196696043015
Loss: 0.03154699504375458
Loss: 0.029723389074206352
Loss: 0.039195798337459564
Loss: 0.032130151987075806
Loss: 0.031276602298021317
Loss: 0.03440115600824356
Loss: 0.030476151034235954

sampling

To sample from the model, you can use the sampling function defined above:

# sample 64 images
samples = sample(model, image_size=image_size, batch_size=64, channels=channels)

# show a random one
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

Sampling result
It looks like the model was able to generate a nice t-shirt! Keep in mind that the dataset used for training has a very low resolution (28x28). It is also possible to create a gif of the denoising process:

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels), cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=1000)
animate.save('diffusion.gif')
plt.show()

gif image

follow-up reading

Note that the DDPM paper showed that diffusion models are a promising direction for (un)conditional image generation. DDPM has since been (vastly) improved, especially for text-conditional image generation. Below is a list of some important (but far from exhaustive) follow-up works, up to June 7, 2022:

  • Improved Denoising Diffusion Probabilistic Models (Nichol et al., 2021): finds that learning the variance of the conditional distribution (in addition to the mean) helps improve performance
  • Cascaded Diffusion Models for High Fidelity Image Generation (Ho et al., 2021): introduces cascaded diffusion, a pipeline of multiple diffusion models that generate images of increasing resolution for high-fidelity image synthesis
  • Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021): shows that diffusion models can outperform SOTA generative models by improving the U-Net architecture and introducing classifier guidance
  • Classifier-Free Diffusion Guidance (Ho et al., 2021): shows that a classifier is not needed to guide a diffusion model; instead, a conditional and an unconditional diffusion model can be jointly trained with a single neural network
  • Hierarchical Text-Conditional Image Generation with CLIP Latents (DALL-E 2) (Ramesh et al., 2022): uses a prior to convert a text caption into a CLIP image embedding, which a diffusion model then decodes into an image
  • Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding (Imagen) (Saharia et al., 2022): shows that combining a large pre-trained language model (such as T5) with cascaded diffusion works well for text-to-image synthesis

reference link

  1. The Annotated Diffusion Model
  2. Take you to understand the diffusion model DDPM
  3. Diffusion model brand new course: Diffusion model realized from 0 to 1!
  4. Denoising Diffusion Probabilistic Models
  5. "Diffusion Models Beat GANs on Image Synthesis" reading notes
  6. How Diffusion Models Work
  7. DDPM cross entropy loss function derivation
  8. Brief introduction of DDPM (Denoising Diffusion Probabilistic Models) diffusion model
  9. What are Diffusion Models?
  10. Understand the Diffusion Model from shallow to deep
  11. What is Diffusion Model?
  12. Probabilistic Diffusion Model probability diffusion model theory and detailed interpretation of complete PyTorch code
  13. Denoising Diffusion Probabilistic Model, in Pytorch
