W-series GANs (Beginner Learns GAN, Part 4)

PS: After studying WGAN, I found that the original WGAN still has some imperfections, and much later work has been filling in those gaps. This article picks two of the main follow-ups to study. Before reading this article, please read Wasserstein GAN (Beginner Learns GAN, Part 3).

Original links: WGAN-GP (https://arxiv.org/abs/1704.00028)

                WGAN-div (https://arxiv.org/abs/1712.01026)

Introduction

"W distance" is proposed in WGAN:

                                      

Here Pr(x) is the distribution of real samples, Pg(y) is the fake distribution, ‖x−y‖ is the transmission cost; and γ∈Π(Pr(x),Pg(y)) means: γ It is an arbitrary binary distribution about the next x and y, and its marginal distribution is Pr(x) and Pg(y). Intuitively, γ describes a transportation plan, and ‖x−y‖ is the transportation cost. Wc(Pr(x),Pg(y)) means to find the cost corresponding to the lowest cost transportation plan as the distribution measure.
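As a quick sanity check of this intuition (my addition, not from the original post): for two 1-D empirical distributions with equal-weight samples and cost ‖x−y‖, the optimal transport plan simply matches sorted samples to sorted samples, so W1 reduces to the mean absolute difference of the two sorted arrays.

import numpy as np

# Toy 1-D Wasserstein-1 distance between two empirical distributions.
# In 1-D, the optimal plan for cost |x - y| pairs the i-th smallest
# real sample with the i-th smallest fake sample.
real = np.sort(np.random.normal(0.0, 1.0, 1000))  # samples from Pr
fake = np.sort(np.random.normal(3.0, 1.0, 1000))  # samples from Pg
w1 = np.abs(real - fake).mean()
print("empirical W1: %.3f" % w1)  # close to 3, the gap between the means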

But because optimizing over the joint distribution γ directly is intractable, the problem is converted to its dual form:

        W(P_{r}, P_{g}) = \sup_{||f||_{L}\leq 1} \mathbb{E}_{x\sim P_{r}}[f(x)] - \mathbb{E}_{y\sim P_{g}}[f(y)]

where f(x) is a scalar function and ||f||_{L} is its Lipschitz norm:

        ||f||_{L} = \max_{x\neq y}\frac{|f(x)-f(y)|}{||x-y||}

so the constraint ||f||_{L} \leq 1 is equivalent to requiring |f(x)-f(y)| \leq ||x-y|| for all x and y.
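As an illustration (my addition, not from the papers), the Lipschitz norm of a network can be probed numerically by sampling random pairs of points. This only yields a lower bound on ||f||_{L}, so it can expose violations of the constraint but never certify that it holds.

import torch
import torch.nn as nn

# Empirical lower bound on ||f||_L for a scalar-valued network f:
# the largest observed ratio |f(x)-f(y)| / ||x-y|| over random pairs.
def lipschitz_lower_bound(f, dim=784, n_pairs=10000):
    x, y = torch.randn(n_pairs, dim), torch.randn(n_pairs, dim)
    with torch.no_grad():
        ratios = (f(x) - f(y)).abs().squeeze(1) / (x - y).norm(dim=1)
    return ratios.max().item()

f = nn.Linear(784, 1)  # stand-in critic; any module mapping R^dim -> R works
print(lipschitz_lower_bound(f))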

Then the training process of WGAN can be expressed as

        \arg\min_{G}\max_{D,||D||_{L}\leq 1} \mathbb{E}_{x\sim P_{r}(x)}[D(x)] - \mathbb{E}_{z\sim P_{z}(z)}[D(G(z))]

where z is drawn from the noise prior P_{z}(z).

The core question: how do we guarantee ||D||_{L}\leq 1?

Solutions (sketches of all three approaches follow this list):

  1. Weight clipping, the method in the original WGAN: after each gradient-descent step on the discriminator, clip the absolute value of every discriminator parameter to at most a fixed constant. This method is crude: it directly limits the discriminator's optimization step, so training is slow and easily falls into poor local optima.
  2. Gradient penalty: construct a reasonable penalty term and add it to the LOSS so that the ||D||_{L}\leq 1 condition holds at all times.
  3. Spectral normalization: divide each layer's weight matrix by its largest singular value, which bounds that layer's Lipschitz constant by 1, so the discriminator used to compute the LOSS satisfies the constraint by construction.
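Illustrative sketches of the three approaches (my additions, not code from the papers):

import torch.nn as nn

# 1. Weight clipping (original WGAN): after every discriminator update,
#    clamp each parameter into [-c, c].
def clip_weights(discriminator, c=0.01):
    for p in discriminator.parameters():
        p.data.clamp_(-c, c)

# 2. Gradient penalty (WGAN-GP): add lambda * (||grad D(x_hat)||_2 - 1)^2
#    to the loss -- see compute_gradient_penalty_gp() in the full code below.

# 3. Spectral normalization: wrap a layer so its weight matrix is divided
#    by its largest singular value, bounding that layer's Lipschitz
#    constant by 1.
sn_layer = nn.utils.spectral_norm(nn.Linear(784, 512))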

Improved Training of Wasserstein GANs

Core idea: since the L constraint must be satisfied at all times, simply turn it into a penalty term and put it in the LOSS.

        L_{D} = \mathbb{E}_{\tilde{x}\sim P_{g}}[D(\tilde{x})] - \mathbb{E}_{x\sim P_{r}}[D(x)] + \lambda\,\mathbb{E}_{\hat{x}\sim \mathbb{P}_{\hat{x}}}\left[(||\nabla_{\hat{x}}D(\hat{x})||_{2}-1)^{2}\right]

Even with the L constraint introduced as a penalty term, the distribution \mathbb{P}_{\hat{x}} in the equation above is still hard to sample from, so the paper uses a trick: draw \hat{x} by randomly mixing a real sample with a generated sample (a random point on the line segment between them), which simulates sampling from the joint distribution. The coefficient \lambda = 10 is chosen empirically.
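A minimal sketch of this interpolation trick (the tensor shapes are stand-ins for the MNIST batches used later):

import torch

# Sample x_hat uniformly on the segment between a real and a fake batch,
# one random mixing weight per sample, then enable grad so that
# ||grad_x_hat D(x_hat)|| can be penalized.
real_batch = torch.randn(64, 1, 28, 28)  # stand-in for real images
fake_batch = torch.randn(64, 1, 28, 28)  # stand-in for generator output
alpha = torch.rand(real_batch.size(0), 1, 1, 1)  # U[0,1], broadcast per sample
x_hat = (alpha * real_batch + (1 - alpha) * fake_batch).requires_grad_(True)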

Wasserstein Divergence for GANs

The core idea: propose a "W divergence" to replace the "W distance"; the corresponding constraint changes accordingly, so the penalty added to the LOSS is different as well.

"W divergence" removes the L constraint, but retains the property that "W distance" can describe the similarity of the distribution of Pr and Pg. The specific proof can be studied in the original paper. Then look directly after using "W divergence", the training process of the network is as follows:

In simple terms, when k=1 and p=2, it is only a constant difference from WGAN-GP. In the above formula, the author points out through experiments that when K=2 and P=6 , the effect is best.
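To see the relationship concretely (my expansion, not from the paper), write out the WGAN-GP penalty and compare it with the k=1, p=2 divergence penalty ||\nabla D||^{2}:

        (||\nabla D||_{2} - 1)^{2} = ||\nabla D||_{2}^{2} - 2\,||\nabla D||_{2} + 1

so the two differ by the linear centering term -2||\nabla D||_{2} plus a constant.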

Code and practice

Reference link ( https://github.com/WingsofFAN/PyTorch-GAN )

import argparse
import os
import numpy as np
import math
import sys

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))  # note: 0.8 is passed positionally as eps here
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity


# Loss weight for gradient penalty
lambda_gp = 10

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def compute_gradient_penalty_div(real_imgs, real_validity, fake_imgs, fake_validity, k=2, p=6):
    """Calculates the W-div gradient penalty; k=2, p=6 follow the paper's recommendation"""
    real_grad_out = Variable(Tensor(real_imgs.size(0), 1).fill_(1.0), requires_grad=False)
    real_grad = autograd.grad(
        real_validity, real_imgs, real_grad_out, create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    real_grad_norm = real_grad.view(real_grad.size(0), -1).pow(2).sum(1) ** (p / 2)

    fake_grad_out = Variable(Tensor(fake_imgs.size(0), 1).fill_(1.0), requires_grad=False)
    fake_grad = autograd.grad(
        fake_validity, fake_imgs, fake_grad_out, create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    fake_grad_norm = fake_grad.view(fake_grad.size(0), -1).pow(2).sum(1) ** (p / 2)

    div_gp = torch.mean(real_grad_norm + fake_grad_norm) * k / 2
    
    return div_gp



def compute_gradient_penalty_gp(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty




# ----------
#  Training
# ----------

batches_done = 0
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Configure input
        real_imgs = Variable(imgs.type(Tensor), requires_grad=True)  # requires_grad so the W-div penalty can differentiate w.r.t. real images

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        fake_imgs = generator(z)

        # Real images
        real_validity = discriminator(real_imgs)
        # Fake images
        fake_validity = discriminator(fake_imgs)
        
        # WGAN-GP loss computation
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty_gp(discriminator, real_imgs.data, fake_imgs.data)
        # Adversarial loss
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty

#--------------------------------------------------------------------------------
        # WGAN-div loss computation (to use it, comment out the two WGAN-GP
        # lines above and uncomment the two lines below)
        # Compute W-div gradient penalty
        # div_gp = compute_gradient_penalty_div(real_imgs,real_validity,fake_imgs,fake_validity)

        # Adversarial loss
        # d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + div_gp
        
#---------------------------------------------------------------------------------


        d_loss.backward()
        optimizer_D.step()

        optimizer_G.zero_grad()

        # Train the generator every n_critic steps
        if i % opt.n_critic == 0:

            # -----------------
            #  Train Generator
            # -----------------

            # Generate a batch of images
            fake_imgs = generator(z)
            # Loss measures generator's ability to fool the discriminator
            # Train on fake images
            fake_validity = discriminator(fake_imgs)
            g_loss = -torch.mean(fake_validity)

            g_loss.backward()
            optimizer_G.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            if batches_done % opt.sample_interval == 0:
                save_image(fake_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

            batches_done += opt.n_critic
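The script as given never saves a checkpoint. Below is a minimal sketch of sampling from the trained generator afterwards, continuing in the same script so the imports and the generator, opt, and Tensor names above are reused; it assumes you add torch.save(generator.state_dict(), "generator.pth") at the end of training (the filename is hypothetical):

# Sample a 5x5 grid from a trained generator checkpoint (hypothetical
# "generator.pth" saved at the end of training).
generator.load_state_dict(torch.load("generator.pth"))
generator.eval()
with torch.no_grad():
    z = torch.randn(25, opt.latent_dim).type(Tensor)
    save_image(generator(z).data, "images/final_samples.png", nrow=5, normalize=True)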

MNIST test

(Result images: sample grids generated on MNIST by WGAN, WGAN_GP, and WGAN_div.)

From the above results, the improvements to the WGAN series are very successful. Looking at the trend from WGAN to WGAN_GP and then to WGAN_div, the restrictions around the "Wasserstein distance" are being progressively relaxed: all three losses are built on the "Wasserstein distance", but each improvement in effect replaces it with a somewhat weaker constraint for optimizing the discriminator. What is thought-provoking is that relaxing the conditions not only did not make it harder to find the modes, it made the modes emerge faster. In the MNIST test above, the three networks were optimized with the same learning rate and number of iterations: WGAN never managed to generate good enough samples, WGAN_GP was generating good samples by about four-fifths of the total iterations, and WGAN_div by about three-fifths.

While consulting the literature I found a paper that also verifies this phenomenon: "How Well Do WGANs Estimate the Wasserstein Metric?" runs experiments and analyses on many GAN architectures based on the "Wasserstein distance" and reaches the corresponding conclusion: a LOSS function that is closer to the "Wasserstein distance" does not necessarily perform better.

This phenomenon is actually not unexpected. What follows is personal conjecture, which may be wrong; corrections are welcome. The sample distribution of a dataset can only approximate the true distribution of the data, and what a GAN actually learns is the sample distribution, which still deviates from the true one. The "Wasserstein distance" describes the distance between two distributions very well, allowing the generator to learn the sample distribution, but what we actually want is the true distribution of the data.

This may sound a bit convoluted, so here is an example. Suppose we want a GAN to generate handwritten Arabic numerals: we hope the generator learns a mapping from a Gaussian distribution to the data distribution of handwritten Arabic numerals. In reality we cannot obtain samples of all handwritten Arabic numerals, so we use the MNIST dataset as a substitute. When optimizing the discriminator, the goal "get close to the data distribution of handwritten Arabic numerals" is therefore replaced by "get close to the data distribution of MNIST". If the generator then produces samples that are close to the true distribution of handwritten numerals but relatively far from the MNIST distribution, the strict "Wasserstein distance" will penalize them, pushing the generator away from the true distribution and toward the MNIST distribution, which degrades the final result. If instead the "Wasserstein distance" is relaxed, so that the distribution learned by the generator only has to be close to the MNIST distribution within some range rather than converging to it completely, the distribution the generator settles on may in fact be closer to the true distribution of handwritten numerals, and the results improve: after the conditions are relaxed, the fit to the "MNIST data distribution" may be looser, but the fit to the "data distribution of handwritten Arabic numerals" may be better.


Original post: blog.csdn.net/fan1102958151/article/details/106341705