PS: After studying WGAN, I found that the original WGAN still has some shortcomings, and a lot of follow-up work has tried to fix them. This article examines two of the main improvements. Before reading it, please read Wasserstein GAN (Little White Learning GAN Series 3).
Original link: WGAN-GP ( https://arxiv.org/abs/1704.00028 )
WGAN-DIV ( https://arxiv.org/abs/1712.01026 )
Introduction
WGAN introduces the "W distance" (Wasserstein distance):
$$W_c(P_r, P_g) = \inf_{\gamma \in \Pi(P_r, P_g)} \mathbb{E}_{(x, y) \sim \gamma}\left[\|x - y\|\right]$$
Here Pr(x) is the distribution of real samples, Pg(y) is the distribution of generated (fake) samples, and ‖x−y‖ is the transport cost. γ∈Π(Pr(x),Pg(y)) means that γ is any joint distribution over x and y whose marginal distributions are Pr(x) and Pg(y). Intuitively, γ describes a transport plan and ‖x−y‖ is the cost of moving x to y. Wc(Pr(x),Pg(y)) is the cost of the cheapest transport plan, and this minimum cost is used as the measure of distance between the two distributions.
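The "cheapest transport plan" idea can be made concrete on a tiny 1-D example. The sketch below is only an illustration under stated assumptions: both distributions are equal-size sets of points with equal weights, so a plan can be taken to be a one-to-one assignment, and the sample values and function names are hypothetical.

```python
from itertools import permutations


def plan_cost(xs, ys, assignment):
    """Average cost of the plan that ships xs[i] to ys[assignment[i]]."""
    return sum(abs(xs[i] - ys[j]) for i, j in enumerate(assignment)) / len(xs)


def wasserstein_brute_force(xs, ys):
    """Try every one-to-one transport plan and keep the cheapest cost."""
    return min(plan_cost(xs, ys, a) for a in permutations(range(len(xs))))


def wasserstein_sorted(xs, ys):
    """In 1-D, matching the sorted points in order is an optimal plan."""
    return sum(abs(x - y) for x, y in zip(sorted(xs), sorted(ys))) / len(xs)


real = [0.0, 2.0, 4.0]  # hypothetical "real" samples
fake = [5.0, 1.0, 3.0]  # hypothetical "generated" samples

print(plan_cost(real, fake, [0, 1, 2]))     # cost of one arbitrary plan
print(wasserstein_brute_force(real, fake))  # cheapest plan over all assignments
print(wasserstein_sorted(real, fake))       # same value, obtained by sorting
```

The brute-force search grows factorially with the number of samples, which hints at why the primal form is impractical for high-dimensional image distributions.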
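Because the infimum over joint distributions γ is hard to compute directly, the problem is converted into its dual form:
$$W_c(P_r, P_g) = \sup_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]$$
where f(x) is a scalar function and ‖f‖_L is its Lipschitz norm:
$$\|f\|_L = \sup_{x \ne y} \frac{|f(x) - f(y)|}{\|x - y\|}$$
That is, f must satisfy
$$\|f\|_L \le 1$$
With f played by the discriminator, G the generator, and z noise drawn from a prior p(z), the training process of WGAN can be expressed as
$$\min_G \max_{\|f\|_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{z \sim p(z)}[f(G(z))]$$
The core question: how can the constraint ‖f‖_L ≤ 1 (the "L constraint") be guaranteed?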
Solutions:
- Weight clipping: the method used in the original WGAN. After each gradient step on the discriminator, every discriminator parameter is clipped so that its absolute value does not exceed a fixed constant. This method is crude: it directly restricts what the discriminator can learn, which makes training slow and prone to getting stuck in local optima.
- Penalty term: construct a suitable penalty term and add it to the loss, so that the constraint is encouraged to hold throughout training.
- Spectral normalization: divide each weight matrix by its spectral norm (its largest singular value), which bounds the Lipschitz constant of every layer. Unlike weight clipping, which acts after each update, the normalization is applied in the forward pass, so the loss is always computed with constrained weights.
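A minimal sketch of the first option, weight clipping, in plain Python. The toy linear critic f(x) = w·x and the helper names are assumptions made for illustration and are not taken from the WGAN code; for such a critic the Lipschitz constant with respect to the L2 norm is simply ‖w‖₂, which makes the effect of clipping easy to see.

```python
import math


def clip_weights(weights, c):
    """Clamp every weight into [-c, c], as done after each critic update."""
    return [max(-c, min(c, w)) for w in weights]


def lipschitz_constant_linear(weights):
    """For the toy critic f(x) = w . x the Lipschitz constant is ||w||_2."""
    return math.sqrt(sum(w * w for w in weights))


weights = [0.5, -0.3, 0.02, -0.008]      # hypothetical critic weights
clipped = clip_weights(weights, c=0.01)  # 0.01 is the script's default clip value

print(clipped)
print(lipschitz_constant_linear(weights))
print(lipschitz_constant_linear(clipped))
# After clipping, ||w||_2 can be at most c * sqrt(number of weights)
print(0.01 * math.sqrt(len(weights)))
```

Clipping does bound the Lipschitz constant, but it also pushes most weights onto the boundary values ±c, which is one way to see why the method is described as crude.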
Improved Training of Wasserstein GANs
Core idea: since the L constraint must hold at all times, simply turn it into a penalty term and add it to the loss:
$$L = \mathbb{E}_{\tilde{x} \sim P_g}[f(\tilde{x})] - \mathbb{E}_{x \sim P_r}[f(x)] + \lambda\, \mathbb{E}_{\hat{x} \sim P_{\hat{x}}}\left[\left(\|\nabla_{\hat{x}} f(\hat{x})\|_2 - 1\right)^2\right]$$
Even as a penalty term, the L constraint would have to hold over the entire sample space, which is intractable to enforce. The paper therefore uses a trick: it draws a real sample x and a generated sample x̃, mixes them with a random weight ε ~ U[0, 1] into x̂ = εx + (1 − ε)x̃, and evaluates the penalty only at these interpolated points, which approximates sampling from the whole space. Here λ = 10 is a coefficient chosen empirically.
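The interpolation trick can be sketched without any deep learning library. In the illustration below the critic's input gradient is supplied by a hypothetical function grad_fn (for the toy critic f(x) = w·x the gradient is w everywhere); in the real script further down, this gradient is obtained with autograd instead. All names and sample values are assumptions for the example.

```python
import math
import random


def interpolate(real, fake, eps):
    """x_hat = eps * real + (1 - eps) * fake, element by element."""
    return [eps * r + (1.0 - eps) * f for r, f in zip(real, fake)]


def gradient_penalty(grad_fn, real_batch, fake_batch, rng):
    """Mean of (||grad f(x_hat)||_2 - 1)^2 over random interpolates x_hat."""
    total = 0.0
    for real, fake in zip(real_batch, fake_batch):
        eps = rng.random()  # one mixing weight per sample pair, in [0, 1)
        grad = grad_fn(interpolate(real, fake, eps))
        norm = math.sqrt(sum(g * g for g in grad))
        total += (norm - 1.0) ** 2
    return total / len(real_batch)


def make_linear_grad(w):
    """Input gradient of the toy critic f(x) = w . x (constant in x)."""
    return lambda x: list(w)


rng = random.Random(0)
real_batch = [[1.0, 2.0], [0.5, -1.0]]  # hypothetical real samples
fake_batch = [[0.0, 0.0], [2.0, 1.0]]   # hypothetical generated samples
lambda_gp = 10

# Gradient norm 1: the penalty is (numerically) zero.
print(lambda_gp * gradient_penalty(make_linear_grad([0.6, 0.8]), real_batch, fake_batch, rng))
# Gradient norm 5: the penalty is lambda_gp * (5 - 1)^2.
print(lambda_gp * gradient_penalty(make_linear_grad([3.0, 4.0]), real_batch, fake_batch, rng))
```

Note that the penalty is two-sided: gradient norms below 1 are penalized as well as norms above 1.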
Wasserstein Divergence for GANs
Core idea: a "W divergence" is proposed to replace the "W distance". The corresponding constraint changes as well, so the penalty added to the loss is different.
The "W divergence" removes the L constraint but retains the property that makes the "W distance" useful: it still measures how similar the distributions Pr and Pg are. See the original paper for the proof. With the "W divergence", the discriminator loss becomes:
$$L = \mathbb{E}_{\tilde{x} \sim P_g}[f(\tilde{x})] - \mathbb{E}_{x \sim P_r}[f(x)] + k\, \mathbb{E}_{\hat{x} \sim P_u}\left[\|\nabla_{\hat{x}} f(\hat{x})\|^p\right]$$
where P_u is the distribution from which the penalty points x̂ are sampled.
In simple terms, when k=1 and p=2 the penalty becomes ‖∇f‖², which resembles the WGAN-GP penalty (‖∇f‖ − 1)² without the shift by 1. The authors report from experiments that k=2 and p=6 work best.
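To compare the two penalties side by side, the sketch below evaluates them on a few hypothetical gradient norms. The function names are assumptions for illustration, and the gradient norms are supplied directly as numbers rather than computed from a network.

```python
def wgan_gp_penalty(grad_norms, lam=10.0):
    """WGAN-GP style penalty: lam * mean((||grad|| - 1)^2)."""
    return lam * sum((n - 1.0) ** 2 for n in grad_norms) / len(grad_norms)


def wdiv_penalty(grad_norms, k=2.0, p=6.0):
    """WGAN-div style penalty: k * mean(||grad||^p)."""
    return k * sum(n ** p for n in grad_norms) / len(grad_norms)


print("norm  WGAN-GP  W-div(k=2,p=6)  W-div(k=1,p=2)")
for norm in [0.0, 0.5, 1.0, 2.0]:
    print(
        norm,
        wgan_gp_penalty([norm]),
        wdiv_penalty([norm]),
        wdiv_penalty([norm], k=1.0, p=2.0),
    )
```

The WGAN-GP penalty vanishes only when the gradient norm equals 1, whereas the W-div penalty is smallest at zero gradient and grows quickly for large norms when p=6.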
Code and practice
Reference link ( https://github.com/WingsofFAN/PyTorch-GAN )
import argparse
import os
import numpy as np
import math
import sys
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch
os.makedirs("images", exist_ok=True)
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--n_critic", type=int, default=5, help="number of training steps for discriminator per iter")
parser.add_argument("--clip_value", type=float, default=0.01, help="lower and upper clip value for disc. weights")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()
print(opt)
img_shape = (opt.channels, opt.img_size, opt.img_size)
cuda = True if torch.cuda.is_available() else False
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity
# Loss weight for gradient penalty
lambda_gp = 10
# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()
if cuda:
    generator.cuda()
    discriminator.cuda()
# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),
    ),
    batch_size=opt.batch_size,
    shuffle=True,
)
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
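def compute_gradient_penalty_div(real_imgs, real_validity, fake_imgs, fake_validity, k=2, p=6):
    """Calculates the gradient penalty loss for WGAN-div.

    k and p are the W-divergence hyperparameters (k=2, p=6 is the setting reported as best).
    real_imgs must have requires_grad=True so that autograd can differentiate w.r.t. it.
    """
    # Gradient of the discriminator output w.r.t. the real images
    real_grad_out = Variable(Tensor(real_imgs.size(0), 1).fill_(1.0), requires_grad=False)
    real_grad = autograd.grad(
        real_validity, real_imgs, real_grad_out, create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    # ||grad||^p, computed as (sum of squares) ** (p / 2)
    real_grad_norm = real_grad.view(real_grad.size(0), -1).pow(2).sum(1) ** (p / 2)
    # Gradient of the discriminator output w.r.t. the generated images
    fake_grad_out = Variable(Tensor(fake_imgs.size(0), 1).fill_(1.0), requires_grad=False)
    fake_grad = autograd.grad(
        fake_validity, fake_imgs, fake_grad_out, create_graph=True, retain_graph=True, only_inputs=True
    )[0]
    fake_grad_norm = fake_grad.view(fake_grad.size(0), -1).pow(2).sum(1) ** (p / 2)
    # Average the penalty over real and generated samples
    div_gp = torch.mean(real_grad_norm + fake_grad_norm) * k / 2
    return div_gp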
def compute_gradient_penalty_gp(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty
# ----------
#  Training
# ----------
batches_done = 0
for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Configure input
        # (for the WGAN-div loss below, use Variable(imgs.type(Tensor), requires_grad=True) instead)
        real_imgs = Variable(imgs.type(Tensor))
        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
        # Generate a batch of images
        fake_imgs = generator(z)
        # Real images
        real_validity = discriminator(real_imgs)
        # Fake images
        fake_validity = discriminator(fake_imgs)
        # WGAN-GP loss computation
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty_gp(discriminator, real_imgs.data, fake_imgs.data)
        # Adversarial loss
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        # --------------------------------------------------------------------------------
        # WGAN-div loss computation (uncomment to use it instead of the WGAN-GP loss above)
        # Compute W-div gradient penalty
        # div_gp = compute_gradient_penalty_div(real_imgs, real_validity, fake_imgs, fake_validity)
        # Adversarial loss
        # d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + div_gp
        # --------------------------------------------------------------------------------
        d_loss.backward()
        optimizer_D.step()
        optimizer_G.zero_grad()
        # Train the generator every n_critic steps
        if i % opt.n_critic == 0:
            # -----------------
            #  Train Generator
            # -----------------
            # Generate a batch of images
            fake_imgs = generator(z)
            # Loss measures generator's ability to fool the discriminator
            # Train on fake images
            fake_validity = discriminator(fake_imgs)
            g_loss = -torch.mean(fake_validity)
            g_loss.backward()
            optimizer_G.step()
            # Log and sample only on generator steps, so g_loss is always up to date
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )
            if batches_done % opt.sample_interval == 0:
                save_image(fake_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)
            batches_done += opt.n_critic
MNIST test
[Figure: generated samples from WGAN, WGAN_GP, and WGAN_div]
Judging from the results above, the improvements in the WGAN series are very successful. Looking at the trend from WGAN to WGAN_GP and then to WGAN_div, the restrictions imposed by the "Wasserstein distance" are gradually being relaxed. All three loss functions are built on the "Wasserstein distance", but in the course of these improvements the "Wasserstein distance" is in fact replaced by a loss with a relatively weaker constraint for optimizing the discriminator. What is thought-provoking is that relaxing the conditions did not make the pattern harder to find; it made the pattern emerge faster. In the MNIST test above, the three networks were optimized with the same learning rate and number of iterations. WGAN failed to generate sufficiently good data by the end, WGAN_GP generated good data after about four-fifths of the total iterations, and WGAN_div generated good data after about three-fifths of the total iterations.
While consulting the literature, I found an article that also verifies this phenomenon. "How Well Do WGANs Estimate the Wasserstein Metric" carries out experiments and analyses on many GAN architectures based on the "Wasserstein distance" and reaches a corresponding conclusion: a loss function that is closer to the "Wasserstein distance" does not necessarily give better performance.
This phenomenon is actually not surprising. The following are some personal conjectures, which may be incorrect, and I welcome corrections. The distribution of the data samples can only approximate the true distribution of the data. What a GAN actually learns is the sample distribution, which still deviates from the true distribution. The "Wasserstein distance" describes the distance between two distributions well and lets the generator learn the distribution of the data samples, but what we actually want is the true distribution of the data.
This may sound convoluted, so here is an example. Suppose we want to use a GAN to generate handwritten Arabic numerals. What we really want is for the generator to learn a mapping from a Gaussian distribution to the distribution of all handwritten Arabic numerals. In reality we cannot obtain samples of every handwritten numeral, so we use the MNIST dataset as a substitute. When optimizing the discriminator, we have therefore replaced the goal "get closer to the distribution of handwritten Arabic numerals" with "get closer to the MNIST distribution". When the generator produces a sample that is close to the distribution of handwritten Arabic numerals but relatively far from the MNIST distribution, the strict "Wasserstein distance" penalizes it, pushing it away from the distribution of handwritten Arabic numerals and toward the MNIST distribution, which lowers the final performance. If we relax the "Wasserstein distance", the generator only has to get within a certain range of the MNIST distribution rather than converge to it completely. A distribution learned this way, close to the MNIST distribution, may actually be closer to the distribution of handwritten Arabic numerals, so the results are better. As shown in the figure below, after the conditions are relaxed, the MNIST distribution may not be fitted as closely, but the distribution of handwritten Arabic numerals may be fitted better.