DDPM on the CIFAR-10 and MNIST datasets


Code source: https://github.com/abarankab/DDPM

Solving the wandb problem:

Step 1: follow https://blog.csdn.net/weixin_43164054/article/details/124156206 step by step.
Step 2: set project_name="cifar" and run python train_cifar.py. If you hit the error "wandb: ERROR It appears that you do not have permission to access the requested resource.", see https://blog.csdn.net/weixin_43835996/article/details/126955917.
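If wandb has not been authenticated on this machine yet, logging in from Python is one option; wandb.login() is a standard wandb API that prompts for (or reuses) the API key from your wandb account:

import wandb

# prompts for, or reuses a cached, API key and authenticates this machine
wandb.login()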

CIFAR-10 dataset

After configuring wandb, remove entity='treaptofun' from DDPM/scripts/train_mnist.py, following the source code on GitHub:

            run = wandb.init(
                project=args.project_name,
                # entity='treaptofun',
                config=vars(args),
                name=args.run_name,
            )

After that, training runs normally.

MNIST dataset

For the MNIST dataset, the following two files need to be modified.

ddpm/script_utils.py

line 90: img_channels=1, because CIFAR images have 3 channels while MNIST images have 1.
line 101: initial_pad=2, because CIFAR images are 32x32, a power of 2, so every downsampling step divides evenly by 2; MNIST images are 28x28, so padding 2 pixels on each side brings them to 32x32 (a quick check of this arithmetic appears after this list).
line 120: CIFAR-10 images are 32x32, MNIST images are 28x28.
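A quick sanity check of the padding arithmetic, assuming the UNet applies initial_pad symmetrically to all four sides before the first convolution (a sketch of the idea, not the repo's exact code):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 28, 28)          # one MNIST image
ip = 2                                  # initial_pad
x_padded = F.pad(x, (ip, ip, ip, ip))   # pad left/right/top/bottom
print(x_padded.shape)                   # torch.Size([1, 1, 32, 32]): now divisible by 2 at every downsampling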

import argparse
import torchvision
import torch.nn.functional as F

from .unet import UNet
from .diffusion import (
    GaussianDiffusion,
    generate_linear_schedule,
    generate_cosine_schedule,
)


def cycle(dl):
    """
    https://github.com/lucidrains/denoising-diffusion-pytorch/
    """
    while True:
        for data in dl:
            yield data

def get_transform():
    class RescaleChannels(object):
        def __call__(self, sample):
            return 2 * sample - 1

    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        RescaleChannels(),
    ])


def str2bool(v):
    """
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise argparse.ArgumentTypeError("boolean value expected")


def add_dict_to_argparser(parser, default_dict):
    """
    https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/script_util.py
    """
    for k, v in default_dict.items():
        v_type = type(v)
        if v is None:
            v_type = str
        elif isinstance(v, bool):
            v_type = str2bool
        parser.add_argument(f"--{
      
      k}", default=v, type=v_type)


def diffusion_defaults():
    defaults = dict(
        num_timesteps=1000,
        schedule="linear",
        loss_type="l2",
        use_labels=False,

        base_channels=128,
        channel_mults=(1, 2, 2, 2),
        num_res_blocks=2,
        time_emb_dim=128 * 4,
        norm="gn",
        dropout=0.1,
        activation="silu",
        attention_resolutions=(1,),

        ema_decay=0.9999,
        ema_update_rate=1,
    )

    return defaults


def get_diffusion_from_args(args):
    activations = {
        "relu": F.relu,
        "mish": F.mish,
        "silu": F.silu,
    }
    # base_channels=128
    model = UNet(
        img_channels=1,  # 1 channel for MNIST (CIFAR-10 uses 3)

        base_channels=args.base_channels,
        channel_mults=args.channel_mults,
        time_emb_dim=args.time_emb_dim,
        norm=args.norm,
        dropout=args.dropout,
        activation=activations[args.activation],
        attention_resolutions=args.attention_resolutions,

        num_classes=None if not args.use_labels else 10,
        initial_pad=2,  # pad 28x28 MNIST up to 32x32
    )
    # line 102: in the CIFAR version this is initial_pad=0

    if args.schedule == "cosine":
        betas = generate_cosine_schedule(args.num_timesteps)
    else:
        betas = generate_linear_schedule(
            args.num_timesteps,
            args.schedule_low * 1000 / args.num_timesteps,
            args.schedule_high * 1000 / args.num_timesteps,
        )

    # This file was modified in 3 places: line 90, line 101, line 120.
    # For CIFAR-10 it was: model, (32, 32), 3, 10,
    # CIFAR-10 images are 32*32 with 3 channels; MNIST images are 28*28 with 1 channel.
    
    
    diffusion = GaussianDiffusion(
        model, (28, 28), 1, 10,
        betas,
        ema_decay=args.ema_decay,
        ema_update_rate=args.ema_update_rate,
        ema_start=2000,
        loss_type=args.loss_type,
    )

    return diffusion

scripts/train_mnist.py

Remove entity='treaptofun'

import argparse
import datetime
import torch
import wandb

from torch.utils.data import DataLoader
from torchvision import datasets
from ddpm import script_utils


def main():
    args = create_argparser().parse_args()
    device = args.device

    try:
        diffusion = script_utils.get_diffusion_from_args(args).to(device)
        optimizer = torch.optim.Adam(diffusion.parameters(), lr=args.learning_rate)


        # resume training from the checkpoint saved before the last interruption
        if args.model_checkpoint is not None:
            diffusion.load_state_dict(torch.load(args.model_checkpoint))
        if args.optim_checkpoint is not None:
            optimizer.load_state_dict(torch.load(args.optim_checkpoint))

        if args.log_to_wandb:
            if args.project_name is None:
                raise ValueError("args.log_to_wandb set to True but args.project_name is None")

            # wandb.init(project="ddpm_cifar")

            run = wandb.init(
                project=args.project_name,
                # entity='treaptofun',
                config=vars(args),
                name=args.run_name,
            )

            wandb.watch(diffusion)

        batch_size = args.batch_size

        train_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_train',
            train=True,
            download=True,
            transform=script_utils.get_transform(),
        )

        test_dataset = datasets.MNIST(
            root='../dataset/mnist/mnist_test',
            train=False,
            download=True,
            transform=script_utils.get_transform(),
        )

        train_loader = script_utils.cycle(DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
            num_workers=2,
        ))
        test_loader = DataLoader(test_dataset, batch_size=batch_size, drop_last=True, num_workers=2)
        
        acc_train_loss = 0

        for iteration in range(1, args.iterations + 1):
            diffusion.train()

            x, y = next(train_loader)
            x = x.to(device)
            y = y.to(device)

            if args.use_labels:
                loss = diffusion(x, y)
            else:
                loss = diffusion(x)

            acc_train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            diffusion.update_ema()
            
            if iteration % args.log_rate == 0:
                test_loss = 0
                with torch.no_grad():
                    diffusion.eval()
                    for x, y in test_loader:
                        x = x.to(device)
                        y = y.to(device)

                        if args.use_labels:
                            loss = diffusion(x, y)
                        else:
                            loss = diffusion(x)

                        test_loss += loss.item()
                
                if args.use_labels:
                    samples = diffusion.sample(10, device, y=torch.arange(10, device=device))
                else:
                    samples = diffusion.sample(10, device)
                
                samples = ((samples + 1) / 2).clip(0, 1).permute(0, 2, 3, 1).numpy()

                test_loss /= len(test_loader)
                acc_train_loss /= args.log_rate

                wandb.log({
                    "test_loss": test_loss,
                    "train_loss": acc_train_loss,
                    "samples": [wandb.Image(sample) for sample in samples],
                })

                acc_train_loss = 0
            
            if iteration % args.checkpoint_rate == 0:
                model_filename = f"{
      
      args.log_dir}/{
      
      args.project_name}-{
      
      args.run_name}-iteration-{
      
      iteration}-model.pth"
                optim_filename = f"{
      
      args.log_dir}/{
      
      args.project_name}-{
      
      args.run_name}-iteration-{
      
      iteration}-optim.pth"

                torch.save(diffusion.state_dict(), model_filename)
                torch.save(optimizer.state_dict(), optim_filename)
        
        if args.log_to_wandb:
            run.finish()
    except KeyboardInterrupt:
        if args.log_to_wandb:
            run.finish()
        print("Keyboard interrupt, run finished early")


def create_argparser():
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    run_name = datetime.datetime.now().strftime("ddpm-%Y-%m-%d-%H-%M")
    defaults = dict(
        learning_rate=2e-4,
        batch_size=128,
        iterations=80000,

        log_to_wandb=True,
        log_rate=1000,
        checkpoint_rate=1000,
        log_dir="./ddpm_logs_mnist",
        project_name="mnist",
        run_name=run_name,

        model_checkpoint=None,
        optim_checkpoint=None,

        schedule_low=1e-4,
        schedule_high=0.02,

        device=device,
    )
    defaults.update(script_utils.diffusion_defaults())

    parser = argparse.ArgumentParser()
    script_utils.add_dict_to_argparser(parser, defaults)
    return parser


if __name__ == "__main__":
    main()

Training command executed on the command line:
python train_mnist.py
Sampling command executed on the command line:
python sample_images.py --model_path "your model path" --save_dir "your save img path" --schedule cosine

Show sampling results

import matplotlib.pyplot as plt
import numpy as np
import os

def show(num_imgs, dir_path):
    '''
    num_imgs: number of images to display
    dir_path: directory containing the images
    '''
    img_names = os.listdir(dir_path)
    img_names.sort(key=lambda x: int(x.split('.')[0]))  # sort numerically by filename

    plt.figure(figsize=(20, 5))  # canvas size
    N = 2
    M = 10
    # lay the images out on an N x M grid
    for i in range(num_imgs):
        path = os.path.join(dir_path, img_names[i])
        img = plt.imread(path)
        plt.subplot(N, M, i + 1)  # subplot indices start at 1, not 0
        plt.imshow(img)
        plt.title(img_names[i], color='black')
        # remove each subplot's own axis ticks; otherwise every image
        # gets its own axes and the grid looks cluttered
        plt.xticks([])
        plt.yticks([])
    plt.show()

print("mnist generation results:")
show(20, './scripts/save_dir_mnist/')  # images saved by the trained model

The name above each image is just the serial number of the generated picture, not a predicted label!

Unconditional training and sampling

Training process:

def get_losses(self, x, t, y):
    noise = torch.randn_like(x)

    perturbed_x = self.perturb_x(x, t, noise)
    estimated_noise = self.model(perturbed_x, t, y)    # the model's input is the noised image
    # the noise the model predicts is the noise at each pixel position!
    # the model's output has the same shape as x: [batch, img_channel, h, w]

    if self.loss_type == "l1":
        loss = F.l1_loss(estimated_noise, noise)
    elif self.loss_type == "l2":
        loss = F.mse_loss(estimated_noise, noise)

    return loss
  • x: (batch_size, img_channel, h, w)
  • t: (batch_size,) — batch_size timesteps drawn uniformly at random from [0, num_timesteps). The diffusion is not carried out step by step during training; t is a tensor of batch size.

Let me explain how this t is injected into the image features. t starts as a tensor of shape (batch_size,); the time embedding followed by a linear layer maps it to (batch_size, out_channels), and indexing with [:, :, None, None] expands it to (batch_size, out_channels, 1, 1). It can then be added to a feature map of shape (batch_size, out_channels, h, w) through broadcasting, i.e. the value added to every pixel of a given channel is the same.
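A minimal sketch of this broadcasting step (the shapes are illustrative, not the repo's exact configuration):

import torch

batch_size, out_channels, h, w = 4, 128, 28, 28
feat = torch.randn(batch_size, out_channels, h, w)  # conv feature map
t_emb = torch.randn(batch_size, out_channels)       # time embedding after the linear layer

out = feat + t_emb[:, :, None, None]  # (4, 128) -> (4, 128, 1, 1), broadcast over h and w
print(out.shape)                      # torch.Size([4, 128, 28, 28])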

  • perturb_x: adds noise to $x_0$ according to the formula $x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, z$; the shape of perturbed_x is (batch_size, img_channel, h, w). A sketch of this step follows the list.
  • model(perturbed_x, t, y): takes the noised image and the corresponding time t, and predicts the added noise. After convolutions, activations, downsampling and upsampling of perturbed_x, the model's output still has shape (batch_size, img_channel, h, w); it is the predicted noise at each pixel position.
  • The loss between the predicted and true noise is computed with an l1 or l2 loss function.
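A minimal sketch of perturb_x, assuming alphas_cumprod holds the precomputed cumulative products $\bar{\alpha}_t$ derived from the beta schedule (the names here are illustrative, not necessarily the repo's exact code):

import torch

def perturb_x(x0, t, noise, alphas_cumprod):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * z
    a_bar = alphas_cumprod[t][:, None, None, None]  # (batch,) -> (batch, 1, 1, 1) for broadcasting
    return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * noise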

Sampling process

    @torch.no_grad()
    def sample(self, batch_size, device, y=None, use_ema=True):
        if y is not None and batch_size != len(y):
            raise ValueError("sample batch size different from length of given y")

        x = torch.randn(batch_size, self.img_channels, *self.img_size, device=device)
        
        for t in range(self.num_timesteps - 1, -1, -1):  # from t = T-1 down to t = 0
            t_batch = torch.tensor([t], device=device).repeat(batch_size)
            x = self.remove_noise(x, t_batch, y, use_ema)

            if t > 0:
                x += extract(self.sigma, t_batch, x.shape) * torch.randn_like(x)
        
        return x.cpu().detach()
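The extract call on the `x += extract(self.sigma, t_batch, x.shape)` line is a small helper that gathers the per-timestep coefficient for each batch element and reshapes it for broadcasting; a typical implementation looks like this (a sketch, not necessarily the repo's exact code):

def extract(a, t, x_shape):
    # pick a[t] for each batch element, then reshape to (batch, 1, 1, ..., 1)
    # so it broadcasts against a tensor of shape x_shape
    b = t.shape[0]
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))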

$$\mathbf{x}_{t-1}=\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\, \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)+\sigma_{t}\, \mathbf{z}$$

  • x: randomly generated noise as the starting point; batch_size is the number of images you want to generate, for example 1k images.
  • t_batch: the denoising of x is performed in batches. Our goal is the chain $x_T, x_{T-1}, x_{T-2}, \ldots, x_1, x_0$; since x contains batch_size images, t_batch denoises all batch_size images at the same timestep simultaneously.
  • remove_noise: computes $\frac{1}{\sqrt{\alpha_{t}}}\left(\mathbf{x}_{t}-\frac{1-\alpha_{t}}{\sqrt{1-\bar{\alpha}_{t}}}\, \boldsymbol{\epsilon}_{\theta}\left(\mathbf{x}_{t}, t\right)\right)$; a sketch follows this list.
  • if t > 0: adds a random noise term $\sigma_{t}\, \mathbf{z}$. Why add random noise during sampling? To simulate the randomness of Brownian motion. When t = 0 we have reached $x_0$, i.e. the original image is obtained in the last step, and no further noise should be added to it!
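A minimal sketch of the remove_noise step, again assuming precomputed alphas and alphas_cumprod tensors (names are illustrative, not the repo's exact code):

import torch

def remove_noise_step(x_t, t, eps_pred, alphas, alphas_cumprod):
    # mean of the reverse step:
    # (1 / sqrt(alpha_t)) * (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta(x_t, t))
    a = alphas[t][:, None, None, None]
    a_bar = alphas_cumprod[t][:, None, None, None]
    return (x_t - (1 - a) / (1 - a_bar).sqrt() * eps_pred) / a.sqrt()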

Conditional training and sampling

Training

Conditional training adds the label y to the network alongside the image:

self.class_bias = nn.Embedding(num_classes, out_channels) if num_classes is not None else None

out += self.class_bias(y)[:, :, None, None]

y is the label. nn.Embedding maps y to shape [batch_size, out_channels]; indexing with [:, :, None, None] expands it to [batch_size, out_channels, 1, 1], and broadcasting then adds it to the feature map out (i.e. to x after the block's convolutions and other operations).

The operation on y here is very similar to the operation on the time t, as the sketch below shows.
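A minimal, self-contained sketch of the class-bias addition (the sizes are illustrative):

import torch
import torch.nn as nn

num_classes, out_channels = 10, 128
class_bias = nn.Embedding(num_classes, out_channels)

y = torch.tensor([3, 7])                     # labels for a batch of 2
out = torch.randn(2, out_channels, 28, 28)   # feature map after convolutions
out = out + class_bias(y)[:, :, None, None]  # (2, 128) -> (2, 128, 1, 1), broadcast over h and w
print(out.shape)                             # torch.Size([2, 128, 28, 28])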

Sampling

        if args.use_labels:
            for label in range(10):
                # assume every class should get the same number of samples:
                # e.g. to generate 1k images over 10 labels, each class gets 100
                y = torch.ones(args.num_images // 10, dtype=torch.long, device=device) * label
                samples = diffusion.sample(args.num_images // 10, device, y=y)

                for image_id in range(len(samples)):
                    image = ((samples[image_id] + 1) / 2).clip(0, 1)
                    torchvision.utils.save_image(image, f"{args.save_dir}/{label}-{image_id}.png")

Sampling is the denoising process: the noise removed at each step is the noise predicted by our trained model. For conditional generation, labels were added during training, so at generation time we can also supply a label, and the noise image is denoised step by step into an $x_0$ that is more likely to belong to the specified label's class.

Contrasting conditional generation and unconditional generation

Hypothesis: the original training dataset contains three classes (cats, dogs, and pigs) with proportions 0.2, 0.3, and 0.5.

  • Conditional generation:
    We can specify which class to generate. For example, when generating 1k images with label=cat, nearly all (roughly 999+) of the 1k generated images will be cats.

  • Unconditional generation:
    We cannot specify a class. Generating 1k images yields roughly the training proportions: about 200 cats, 300 dogs, and 500 pigs.

Origin: https://blog.csdn.net/weixin_43845922/article/details/129817081