Wasserstein GAN (GAN for Beginners Series, Part III)

Original link: https://arxiv.org/abs/1701.07875

Introduction

Core idea: use the Wasserstein distance to improve how a GAN is optimized and so stabilize its training. The method is very simple: remove the sigmoid from the discriminator, remove the log from the losses of the generator and the discriminator, clip the discriminator's weights to a fixed range [-c, c] after every update, and avoid momentum-based optimizers.

Intuitively, JSD is an unreasonable loss for the original GAN. As the figure on the right shows, it is zero only when the two distributions coincide, and that point is discontinuous with all the others, so the model is very hard to optimize with stochastic gradient descent. Removing the sigmoid layer means the output is no longer squashed nonlinearly into (0, 1), so it varies more continuously. Without the sigmoid, however, the range of the output becomes much larger; when the outputs (and hence the loss) grow too large, the discriminator quickly converges to an extreme, which hurts the generator's ability to fit the original data. The authors therefore bound the discriminator's parameters after each update. The purpose is to slow the convergence of the generator and the discriminator and keep the model from collapsing. Finally, the authors advise against momentum-based optimizers; the paper reports that training sometimes became unstable with optimizers such as Adam. My reading is that this keeps the direction of optimization more random, so that the distribution built by the generator can cover the original data distribution more completely.
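To make "removing the sigmoid and the log" concrete, here is a minimal sketch that compares the two losses on raw discriminator scores. The function names and the sample scores are illustrative assumptions, not code from the paper or the reference repository. The original GAN loss is written with softplus, which equals the usual sigmoid-plus-log form but is numerically stable.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def gan_discriminator_loss(real_scores, fake_scores):
    """Original GAN: -E[log sigmoid(s_real)] - E[log(1 - sigmoid(s_fake))]."""
    return mean([softplus(-s) for s in real_scores]) + mean([softplus(s) for s in fake_scores])


def wgan_critic_loss(real_scores, fake_scores):
    """WGAN: no sigmoid, no log, just a difference of mean scores."""
    return -mean(real_scores) + mean(fake_scores)


def wgan_generator_loss(fake_scores):
    """WGAN generator: raise the critic's score on generated samples."""
    return -mean(fake_scores)


if __name__ == "__main__":
    # As the discriminator grows more confident (larger scale), the original
    # loss flattens out near zero while the WGAN loss keeps changing linearly.
    for scale in (1.0, 5.0, 10.0):
        real = [2.0 * scale, 1.0 * scale]
        fake = [-1.0 * scale, -2.0 * scale]
        print(scale, gan_discriminator_loss(real, fake), wgan_critic_loss(real, fake))
```

The WGAN loss has no built-in bound, which is why the weights of the critic have to be constrained separately.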

The loss discontinuity problem

     

From the formula above, it is easy to see that JS, KL, and the sigmoid-based loss all jump abruptly as the distributions are fitted, and most of the time they are constant functions that cannot describe how similar the two distributions are. W, in contrast, is a linear function with no jumps, so it can describe the distance between the two distributions throughout optimization.

PS: A good loss function should return a penalty that matches how far off the model's output is, so that the model has a clear direction in which to optimize.

The loss designed for the original GAN, however, gives feedback on the similarity of the two distributions that does not reflect its degree. This is precisely why the improvement in Wasserstein GAN makes GANs so much easier to train.
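A small numerical sketch of this point, under the assumption that the real and generated distributions are single point masses on an integer grid (a simplified stand-in for the disjoint-support case; the helper names are made up for illustration). The JS divergence is stuck at log 2 whenever the two masses do not overlap, while the Wasserstein distance grows with the gap between them.

```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) of two discrete distributions."""
    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def wasserstein_1d(p, q):
    """W1 on the grid 0, 1, 2, ... with unit spacing: sum of |CDF_p - CDF_q|."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q)
    return total


def point_mass(position, size=10):
    """Distribution that puts all probability on one grid cell."""
    return [1.0 if i == position else 0.0 for i in range(size)]


if __name__ == "__main__":
    real = point_mass(0)
    for shift in (0, 1, 3, 7):
        fake = point_mass(shift)
        print(shift, js_divergence(real, fake), wasserstein_1d(real, fake))
```

Only the Wasserstein column tells the generator whether a shift of 3 is better than a shift of 7.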

Building the model

Returning to the original idea behind GANs: we want a generator and a discriminator to improve each other through adversarial learning, so the two must stay roughly equal in strength. During optimization, however, sometimes the generator trains faster and sometimes the discriminator does, so one ends up beating the other and training fails. In practice, the discriminator, which has clear labels, usually trains faster.

Based on this, the paper proposes a suppression method that keeps the network's weights within a fixed range after every update. Let's look at the design of the loss in detail:

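The Wasserstein distance is also called the Earth-Mover (EM) distance:

$$W(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}\left[\lVert x - y \rVert\right]$$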

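γ is a joint distribution whose marginals are Pr and Pg, and the infimum is taken over all such joint distributions. The whole formula describes how similar the distributions Pr and Pg are. However, because the infimum over γ is hard to compute, the authors transform the formula into an equivalent form (the derivation is in the appendix of the original paper):

$$K \cdot W(P_r, P_g) = \sup_{\lVert f \rVert_L \le K} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]$$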

Here K is the Lipschitz constant: the function f(x) must satisfy the inequality below for every pair of inputs. In other words, K is an upper bound on the slope of f(x):

$$|f(x_1) - f(x_2)| \le K \, |x_1 - x_2|$$

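W is then the supremum of the difference between the expectations of f under Pr and Pg, divided by K. If f is replaced by a family of functions f_w parameterized by w, then:

$$K \cdot W(P_r, P_g) \approx \max_{w} \; \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{x \sim P_g}[f_w(x)]$$

This formula means that we can use deep learning to search the family f_w for a function that approximates W(Pr, Pg), and then train the generator to minimize it, that is, to make Pg fit Pr. In practice we do not solve for K. Instead, the weights w are clipped to a small range around 0, which keeps f_w Lipschitz for some K. The final loss simplifies to:

$$L = \mathbb{E}_{x \sim P_r}[f_w(x)] - \mathbb{E}_{x \sim P_g}[f_w(x)]$$

The discriminator is trained to maximize L, and the generator is trained to minimize it.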

Here f_w is the discriminator D, Pr is the real distribution, and Pg is the generated distribution.
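One way to see what weight clipping does to the Lipschitz constant: for a linear function f_w(x) = w · x, the constant with respect to the Euclidean norm is the norm of w, so clipping every weight to [-c, c] caps it at c · sqrt(dim). The sketch below checks this numerically. The linear critic and all helper names are assumptions chosen for illustration; the real discriminator is a multi-layer network, for which clipping gives a bound but not this exact value.

```python
import random


def clip(weights, c):
    """Clamp every weight to the interval [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]


def critic(weights, x):
    """Linear critic f_w(x) = w . x (a toy stand-in for the discriminator)."""
    return sum(w * xi for w, xi in zip(weights, x))


def l2_norm(v):
    return sum(vi * vi for vi in v) ** 0.5


def empirical_lipschitz(weights, n_pairs=1000, seed=0):
    """Largest observed |f(x1) - f(x2)| / ||x1 - x2|| over random input pairs."""
    rng = random.Random(seed)
    dim = len(weights)
    best = 0.0
    for _ in range(n_pairs):
        x1 = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
        x2 = [rng.uniform(-1.0, 1.0) for _ in range(dim)]
        dist = l2_norm([a - b for a, b in zip(x1, x2)])
        if dist > 0:
            gap = abs(critic(weights, x1) - critic(weights, x2))
            best = max(best, gap / dist)
    return best


if __name__ == "__main__":
    raw = [5.0, -3.0, 0.2, 1.5]
    clipped = clip(raw, 0.01)
    print("unclipped slope:", empirical_lipschitz(raw), "bound:", l2_norm(raw))
    print("clipped slope:  ", empirical_lipschitz(clipped), "bound:", 0.01 * len(raw) ** 0.5)
```

With c = 0.01 the slope of this toy critic is at most 0.02, which also hints at why a very small clipping range can make the training signal weak.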

Algorithm flow

The figure above shows intuitively why WGAN's loss is easier to optimize.
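The training schedule can be sketched on a one-dimensional toy problem: update the critic on every batch, clip its weight, and update the generator once every n_critic batches. Everything here is an assumption made for illustration: the critic is f_w(x) = w * x, the generator is g(z) = z + theta, the real data is N(3, 1), and plain gradient steps are used instead of RMSProp so that the example needs only the standard library.

```python
import random


def train_toy_wgan(n_steps=1500, n_critic=5, clip_value=0.1,
                   lr_critic=0.05, lr_gen=0.5, batch_size=64, seed=0):
    """Toy WGAN loop: returns the learned generator shift and critic weight."""
    rng = random.Random(seed)
    w = 0.0      # critic parameter, f_w(x) = w * x
    theta = 0.0  # generator parameter, g(z) = z + theta
    for step in range(n_steps):
        real = [rng.gauss(3.0, 1.0) for _ in range(batch_size)]
        fake = [rng.gauss(0.0, 1.0) + theta for _ in range(batch_size)]
        mean_real = sum(real) / batch_size
        mean_fake = sum(fake) / batch_size

        # Critic step: ascend E[f_w(real)] - E[f_w(fake)].
        # For the linear critic the gradient w.r.t. w is mean_real - mean_fake.
        w += lr_critic * (mean_real - mean_fake)

        # Weight clipping keeps the critic's slope within [-clip_value, clip_value].
        w = max(-clip_value, min(clip_value, w))

        # Generator step every n_critic batches: descend -E[f_w(g(z))].
        # The gradient w.r.t. theta is -w, so theta moves by +lr_gen * w.
        if step % n_critic == 0:
            theta += lr_gen * w
    return theta, w


if __name__ == "__main__":
    theta, w = train_toy_wgan()
    print("generator shift:", theta, "critic weight:", w)
```

In this toy setting the generator shift should drift toward the real mean of 3 and then oscillate around it, because the clipped critic only ever pushes it with a bounded slope.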

Code and practice

Reference link ( https://github.com/WingsofFAN/PyTorch-GAN/blob/master/implementations/wgan/wgan.py )

import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.00005, help="learning rate")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = torch.cuda.is_available()


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
                
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            # sigmoid removed: the output is an unbounded score
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity


# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.RMSprop(generator.parameters(), lr=opt.lr)
optimizer_D = torch.optim.RMSprop(discriminator.parameters(), lr=opt.lr)

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

batches_done = 0
for epoch in range(opt.n_epochs):

    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        fake_imgs = generator(z).detach()
        # Adversarial loss: compute the discriminator loss on real_imgs and fake_imgs
        # (minimizing this maximizes D(real) - D(fake))
        loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))

        loss_D.backward()
        optimizer_D.step()

        # Clip weights of discriminator to [-clip_value, clip_value]
        # (this limits how far the discriminator can move with each update)
        for p in discriminator.parameters():
            p.data.clamp_(-opt.clip_value, opt.clip_value)

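        # Train the generator every n_critic iterations
        if i % opt.n_critic == 0:
            # Because D's updates are suppressed by clipping,
            # the D that used to train too fast now lags behind,
            # so G is trained only once every few batches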

            # -----------------
            #  Train Generator
            # -----------------

            optimizer_G.zero_grad()

            # Generate a batch of images
            gen_imgs = generator(z)
            # Adversarial loss
            
            loss_G = -torch.mean(discriminator(gen_imgs))

            loss_G.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, batches_done % len(dataloader), len(dataloader), loss_D.item(), loss_G.item())
            )

        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
        batches_done += 1

MNIST test

In my test I found that WGAN actually trains very slowly. This is a side effect of weight clipping. How to solve this problem is the topic of the next installment.

 

 

 


Origin blog.csdn.net/fan1102958151/article/details/106326650