Paper: Denoising Diffusion Probabilistic Models
Reference: The Annotated Diffusion Model
Purpose of DDPM:
A noise image is sampled from the standard normal distribution and, after T denoising steps, is restored into an image resembling the training images, completing the image-generation task.
The DDPM approach:
① Diffusion process (noising process):
Noise is added to a training image repeatedly; after T steps, the image approximately becomes pure noise, standard-normal and independent in each dimension.
Each noising step is denoted $q(x_t \mid x_{t-1})$, where $t$ is the current time step (after $t$ noising steps), $t-1$ is the previous step, $x_t$ is the image at the current step, and $x_{t-1}$ is the image at the previous step.
The whole process is a Markov chain: the image at the current step depends only on the image at the previous step and on no other step. Define a sequence $\beta$ of length $T$ whose entries $\beta_t$ increase monotonically within the interval $(0,1)$. The variance of the noise added at step $t$ is $\beta_t$, and the mean is determined jointly by $\beta_t$ and $x_{t-1}$, which gives the single-step transition $q(x_t \mid x_{t-1})$ and the whole diffusion process $q(x_{1:T} \mid x_0)$:
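As given in the DDPM paper:
$$q(x_t \mid x_{t-1}) = \mathcal{N}\big(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t\mathbf{I}\big), \qquad q(x_{1:T} \mid x_0) = \prod_{t=1}^{T} q(x_t \mid x_{t-1})$$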
It can be seen that the noisy image $x_t$ at any step is determined by the initial image (the original image) $x_0$ and the $\beta$ sequence alone. Define $\alpha_t = 1-\beta_t$ and $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$; then:
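As shown in the paper, via the reparameterization trick:
$$q(x_t \mid x_0) = \mathcal{N}\big(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\mathbf{I}\big), \qquad x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,z, \quad z \sim \mathcal{N}(0,\mathbf{I})$$
This is exactly what `q_sample` in the code below computes.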
As $T$ approaches $\infty$, $x_T$ can be regarded as standard normal noise, independent in each dimension.
The diffusion process does not involve the network: given the initial image $x_0$ and the $\beta$ sequence, the entire diffusion process is determined.
② Reverse diffusion process (denoising process):
The noise image is denoised repeatedly; after T steps, it is restored to the initial image $x_0$.
Each denoising step is denoted $p(x_{t-1} \mid x_t)$, with formula:
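Each learned reverse step is taken to be Gaussian (as in the paper):
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t, t),\ \Sigma_\theta(x_t, t)\big), \qquad p_\theta(x_{0:T}) = p(x_T)\prod_{t=1}^{T} p_\theta(x_{t-1} \mid x_t)$$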
$p(x_{t-1} \mid x_t)$ is intractable to compute directly, so a network is used to approximate it.
There are many ways to fit this distribution with a network. The author builds a network $D_\theta(x_t, t)$ ($\theta$ denotes the network parameters) that takes $x_t$ and $t$ as input and outputs the noise $z_t$ at step $t$ (rather than predicting $x_{t-1}$ directly). The variance of the reverse step is assumed fixed, so only the mean $\mu_\theta$ enters the loss computation.
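With this noise parameterization, the mean used at sampling time is (following the paper, writing the network's noise prediction $D_\theta(x_t, t)$ as $z_\theta(x_t, t)$):
$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,z_\theta(x_t, t)\right)$$
This is the `model_mean` computed in `p_sample` below.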
Loss function:
The goal of the generative network is to make the generated distribution $p_\theta(x_0)$ as close as possible to the original data distribution $q(x_0)$. Starting from their KL divergence $KL(q \,\|\, p_\theta)$, an upper bound on the negative log-likelihood can be written; maximizing the likelihood is then achieved by minimizing this upper bound:
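As in the paper, the variational upper bound is:
$$\mathbb{E}\big[-\log p_\theta(x_0)\big] \le \mathbb{E}_q\left[-\log \frac{p_\theta(x_{0:T})}{q(x_{1:T} \mid x_0)}\right] =: L$$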
Expressed in terms of KL divergences, it can be written as the following decomposition:
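$$L = \mathbb{E}_q\bigg[\underbrace{D_{KL}\big(q(x_T \mid x_0)\,\|\,p(x_T)\big)}_{L_T} + \sum_{t>1}\underbrace{D_{KL}\big(q(x_{t-1} \mid x_t, x_0)\,\|\,p_\theta(x_{t-1} \mid x_t)\big)}_{L_{t-1}} \underbrace{-\log p_\theta(x_0 \mid x_1)}_{L_0}\bigg]$$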
Among these, the first term $L_T$ contains no model parameters $\theta$; it is a constant and plays no role in the optimization. The second term $L_{t-1}$ and the third term $L_0$ can be expanded and simplified ($L_0$ can be viewed as $L_{t-1}$ with $t=1$).
Note that $q(x_{t-1} \mid x_t, x_0)$ is tractable and can be simplified using Bayes' rule and the reparameterization trick:
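As in the paper:
$$q(x_{t-1} \mid x_t, x_0) = \mathcal{N}\big(x_{t-1};\ \tilde{\mu}_t(x_t, x_0),\ \tilde{\beta}_t\mathbf{I}\big), \qquad \tilde{\beta}_t = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$
$$\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1-\bar{\alpha}_t}\,x_0 + \frac{\sqrt{\alpha_t}\,(1-\bar{\alpha}_{t-1})}{1-\bar{\alpha}_t}\,x_t$$
($\tilde{\beta}_t$ is the `posterior_variance` computed in the code below.)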
The loss term $L_{t-1}$ (allowing $t=1$ so that the $L_0$ term is included):
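As derived in the paper, with $z \sim \mathcal{N}(0,\mathbf{I})$ and a constant $C$ that does not depend on $\theta$:
$$L_{t-1} = \mathbb{E}_{x_0, z}\left[\frac{\beta_t^2}{2\sigma_t^2\,\alpha_t(1-\bar{\alpha}_t)}\,\Big\|z - z_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,z,\ t\big)\Big\|^2\right] + C$$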
The author drops the weighting coefficient to obtain the simplified loss (as in the paper):
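$$L_{\text{simple}} = \mathbb{E}_{t,\,x_0,\,z}\left[\Big\|z - z_\theta\big(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,z,\ t\big)\Big\|^2\right]$$
This is exactly what `p_losses` below implements (with l1/l2/huber variants of the distance).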
Discussion of $L_0$:
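In the paper, $p_\theta(x_0 \mid x_1)$ is an independent discrete decoder that integrates the Gaussian density over each pixel's quantization bin:
$$p_\theta(x_0 \mid x_1) = \prod_{i=1}^{D}\int_{\delta_-(x_0^i)}^{\delta_+(x_0^i)} \mathcal{N}\big(x;\ \mu_\theta^i(x_1, 1),\ \sigma_1^2\big)\,dx$$
where $D$ is the data dimensionality, $\delta_-(x) = x - \tfrac{1}{255}$ (or $-\infty$ at $x=-1$) and $\delta_+(x) = x + \tfrac{1}{255}$ (or $+\infty$ at $x=1$).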
It can be seen that gradients propagate only when the pixel values $x$ lie inside the $(-1,1)$ interval, so the author rescales the input images from $[0,255]$ to $[-1,1]$ to keep the backpropagation process working normally.
DDPM training and sampling process:
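In brief, the paper's Algorithm 1 (training) repeats: sample $x_0 \sim q(x_0)$, $t \sim \mathrm{Uniform}(\{1,\dots,T\})$, $z \sim \mathcal{N}(0,\mathbf{I})$, then take a gradient step on $\big\|z - z_\theta(\sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,z,\ t)\big\|^2$. Algorithm 2 (sampling) starts from $x_T \sim \mathcal{N}(0,\mathbf{I})$ and, for $t = T,\dots,1$, computes
$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\,z_\theta(x_t, t)\right) + \sigma_t z, \quad z \sim \mathcal{N}(0,\mathbf{I})\ (z = 0 \text{ when } t = 1)$$
and returns $x_0$. These two loops correspond to `p_losses` and `p_sample_loop` in the code below.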
The structure of DDPM:
It is based on the U-Net architecture, with sinusoidal position (time-step) embeddings, residual structures, attention mechanisms, and group normalization added.
train.py
import os
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from torchvision import transforms, datasets
from model import Unet  # the DDPM U-Net model
# Four β-schedule generators; each takes the total number of steps T and returns the β sequence
def cosine_beta_schedule(timesteps, s=0.008):
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def linear_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
# Gather a[t] for each element of the batch and reshape so it broadcasts against x_shape (first dim is batch_size)
def extract(a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(-1, t.cpu())
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)
# Forward diffusion sampling: compute x_t directly from x_0 and t
def q_sample(x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
    sqrt_alphas_cumprod_t = extract(sqrt_alphas_cumprod, t, x_start.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape)
    return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
# Loss function with three options; the original paper uses l2
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
if noise is None:
noise = torch.randn_like(x_start)
x_noisy = q_sample(x_start, t, noise)
predicted_noise = denoise_model(x_noisy, t)
if loss_type == "l1":
loss = F.l1_loss(noise, predicted_noise)
elif loss_type == "l2":
loss = F.mse_loss(noise, predicted_noise)
elif loss_type == "huber":
loss = F.smooth_l1_loss(noise, predicted_noise)
else:
raise NotImplementedError()
return loss
# Reverse diffusion sampling: compute x_{t-1} from x_t and t; this step requires the network
@torch.no_grad()
def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
if t_index == 0:
return model_mean
else:
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = torch.randn_like(x)
return model_mean + torch.sqrt(posterior_variance_t) * noise
# Run all T reverse diffusion steps: start from x_T and collect the image at every step; requires the network
@torch.no_grad()
def p_sample_loop(model, shape):
device = next(model.parameters()).device
b = shape[0]
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc="sampling loop time step", total=timesteps):
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu())
return imgs
# Full reverse process; batch_size sets the number of generated images (used to visualize results)
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=1):
return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))
if __name__ == "__main__":
    timesteps = 300  # total number of steps T
    # The following are all length-T sequences; index with t to get the value at step t
    betas = linear_beta_schedule(timesteps=timesteps)  # pick one schedule to generate β(t)
    alphas = 1. - betas  # α(t)
    alphas_cumprod = torch.cumprod(alphas, dim=0)  # cumulative product of α, i.e. ᾱ(t)
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)  # drop the last ᾱ and prepend 1: the previous step's ᾱ, i.e. ᾱ(t-1)
    sqrt_recip_alphas = torch.sqrt(1. / alphas)  # 1/√α(t)
    sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)  # √ᾱ(t)
    sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)  # √(1-ᾱ(t))
    posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)  # β(t)·(1-ᾱ(t-1))/(1-ᾱ(t)), i.e. β̃(t)
total_epochs = 10
image_size = 28
channels = 1
batch_size = 256
lr = 1e-3
os.makedirs("../dataset/mnist", exist_ok=True)
os.makedirs("images", exist_ok=True)
transform = transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1)  # map inputs from (0,1) to (-1,1)
])
dataset = datasets.MNIST(root="../dataset/mnist", train=True, transform=transform, download=True)
    reverse_transform = transforms.Compose([  # convert a tensor back to a PIL image
transforms.Lambda(lambda t: (t + 1) / 2),
transforms.Lambda(lambda t: t.permute(1, 2, 0)),
transforms.Lambda(lambda t: t * 255),
transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
transforms.ToPILImage()
])
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Unet(dim=image_size, channels=channels, dim_mults=(1, 2, 4))
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(total_epochs):
total_loss = 0
        pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{total_epochs}", miniters=0.3)
for iter, (img, _) in enumerate(dataloader):
img = img.to(device)
optimizer.zero_grad()
batch_size = img.shape[0]
t = torch.randint(0, timesteps, (batch_size,), device=device).long()
            loss = p_losses(model, img, t, loss_type="huber")  # pick the loss type and compute the loss
loss.backward()
optimizer.step()
total_loss += loss.item()
            pbar.set_postfix(Loss=loss.item())
pbar.update(1)
pbar.close()
print("total_loss:%.4f" %
(total_loss / len(dataloader)))
        # Visualize the generation (denoising) process of one image: one frame every 3 steps, 100 frames shown in a single figure
val_images = sample(model, image_size, batch_size=1, channels=channels)
fig, axs = plt.subplots(10, 10, figsize=(20, 20))
plt.rc("text", color="blue")
for t in range(100):
i = t // 10
j = t % 10
image = val_images[t * 3 + 2].squeeze(0)
image = reverse_transform(image)
axs[i, j].imshow(image, cmap="gray")
axs[i, j].set_axis_off()
            axs[i, j].set_title(r"$q(\mathbf{x}_{" + str(300 - 3 * t - 3) + "})$")
plt.savefig(f"images/{
epoch + 1}.png", bbox_inches='tight')
plt.close()
model.py
import math
from inspect import isfunction
from functools import partial
from einops import rearrange, reduce
from einops.layers.torch import Rearrange
import torch
from torch import nn, einsum
import torch.nn.functional as F
def exists(x):
return x is not None
def default(val, d):
    if exists(val):
return val
return d() if isfunction(d) else d
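# Residual wrapper: adds the module's input to its output (skip connection)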
class Residual(nn.Module):
def __init__(self, fn):
super(Residual, self).__init__()
self.fn = fn
def forward(self, x, *args, **kwargs):
return self.fn(x, *args, **kwargs) + x
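# Upsample by 2x with nearest-neighbor interpolation followed by a 3x3 conv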
def Upsample(dim, dim_out=None):
return nn.Sequential(
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2d(dim, default(dim_out, dim), 3, 1, 1)
)
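# Downsample by 2x via a space-to-depth rearrange followed by a 1x1 conv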
def Downsample(dim, dim_out=None):
return nn.Sequential(
Rearrange("b c (h p1) (w p2) -> b (c p1 p2) h w", p1=2, p2=2),
nn.Conv2d(dim * 4, default(dim_out, dim), 1, 1, 0)
)
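# Sinusoidal time-step embeddings, analogous to Transformer position encodings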
class SinusoidalPositionEmbeddings(nn.Module):
def __init__(self, dim):
super(SinusoidalPositionEmbeddings, self).__init__()
self.dim = dim
def forward(self, time):
device = time.device
half_dim = self.dim // 2
embeddings = math.log(10000) / (half_dim - 1)
embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
embeddings = time[:, None] * embeddings[None, :]
embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=1)
return embeddings
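# Conv2d with weight standardization: each output filter's weights are normalized to zero mean and unit variance, which pairs well with group normalization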
class WeightStandardizedConv2d(nn.Conv2d):
def forward(self, x):
eps = 1e-5 if x.dtype == torch.float32 else 1e-3
weight = self.weight
mean = reduce(weight, "o ... -> o 1 1 1", "mean")
var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False))
normalized_weight = (weight - mean) * (var + eps).rsqrt()
return F.conv2d(
x,
normalized_weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.groups
)
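# Basic block: weight-standardized conv -> group norm -> SiLU, with optional scale-shift conditioning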
class Block(nn.Module):
def __init__(self, dim, dim_out, groups=8):
super(Block, self).__init__()
self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1)
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()
def forward(self, x, scale_shift=None):
x = self.proj(x)
x = self.norm(x)
        if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift
x = self.act(x)
return x
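# ResNet block: two Blocks plus a learned (or identity) skip; the time embedding is mapped to a (scale, shift) pair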
class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
super(ResnetBlock, self).__init__()
self.mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None
)
self.block1 = Block(dim, dim_out, groups=groups)
self.block2 = Block(dim_out, dim_out, groups=groups)
self.res_conv = nn.Conv2d(dim, dim_out, 1, 1, 0) if dim != dim_out else nn.Identity()
def forward(self, x, time_emb=None):
scale_shift = None
        if exists(self.mlp) and exists(time_emb):
time_emb = self.mlp(time_emb)
time_emb = rearrange(time_emb, "b c -> b c 1 1")
scale_shift = time_emb.chunk(2, dim=1)
h = self.block1(x, scale_shift=scale_shift)
h = self.block2(h)
return h + self.res_conv(x)
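# Standard multi-head self-attention over all spatial positions (quadratic in the number of pixels)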
class Attention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(Attention, self).__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, 1, 0, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1, 1, 0)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q * self.scale
sim = einsum("b h d i, b h d j -> b h i j", q, k)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
attn = sim.softmax(dim=-1)
out = einsum("b h i j, b h d j -> b h i d", attn, v)
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
return self.to_out(out)
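# Linear attention: an attention variant whose cost is linear in the number of pixels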
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super(LinearAttention, self).__init__()
self.scale = dim_head ** -0.5
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, 1, 0, bias=False)
self.to_out = nn.Sequential(
nn.Conv2d(hidden_dim, dim, 1, 1, 0),
nn.GroupNorm(1, dim)
)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x).chunk(3, dim=1)
q, k, v = map(
lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv
)
q = q.softmax(dim=-2)
k = k.softmax(dim=-1)
q = q * self.scale
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
return self.to_out(out)
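# Applies group normalization before the wrapped function/module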
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super(PreNorm, self).__init__()
self.fn = fn
self.norm = nn.GroupNorm(1, dim)
def forward(self, x):
x = self.norm(x)
return self.fn(x)
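# The U-Net noise predictor: takes the noisy image x_t and the time step t, returns the predicted noise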
class Unet(nn.Module):
def __init__(self, dim, init_dim=None, out_dim=None, dim_mults=(1, 2, 4, 8), channels=3, self_condition=False,
resnet_block_groups=4):
super(Unet, self).__init__()
self.channels = channels
self.self_condition = self_condition
input_channels = channels * (2 if self_condition else 1)
init_dim = default(init_dim, dim)
self.init_conv = nn.Conv2d(input_channels, init_dim, 1, 1, 0)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
time_dim = dim * 4
self.time_mlp = nn.Sequential(
SinusoidalPositionEmbeddings(dim),
nn.Linear(dim, time_dim),
nn.GELU(),
nn.Linear(time_dim, time_dim)
)
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(
nn.ModuleList([
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, 1, 1)
])
)
mid_dim = dims[-1]
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind == (len(in_out) - 1)
self.ups.append(
nn.ModuleList([
block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim),
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, 1, 1)
])
)
self.out_dim = default(out_dim, channels)
self.final_res_block = block_klass(dim * 2, dim, time_emb_dim=time_dim)
self.final_conv = nn.Conv2d(dim, self.out_dim, 1, 1, 0)
def forward(self, x, time, x_self_cond=None):
if self.self_condition:
x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x))
x = torch.cat((x_self_cond, x), dim=1)
x = self.init_conv(x)
r = x.clone()
t = self.time_mlp(time)
h = []
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
h.append(x)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = torch.cat((x, h.pop()), dim=1)
x = block2(x, t)
x = attn(x)
x = upsample(x)
x = torch.cat((x, r), dim=1)
x = self.final_res_block(x, t)
return self.final_conv(x)
if __name__ == "__main__":
model = Unet(28)
print(model)
Schematic diagram of the generated images (denoising process):