Original paper: Ming-Yu Liu and Oncel Tuzel, "Coupled Generative Adversarial Networks" (NIPS 2016), https://arxiv.org/abs/1606.07536
Introduction
Background: From Conditional GAN to InfoGAN, a line of work has tried to solve the same problem: how to steer a GAN, under human control, so that it generates the data we want. Most of those earlier methods, however, act on the model's inputs and outputs. This paper instead designs a dual (or, more generally, multiple) GAN architecture and uses the mutual constraints between different domains to control the style of the generated data.
Core idea: a pair of GANs share part of their weights, so that the two data domains jointly constrain the generated data.
As the figure above shows, if we split it down the middle we get two independent GANs. The difference is that the two GANs share the weights of the first few layers of their generators and of the last few layers of their discriminators. Why share weights? Intuitively, the GAN is being applied to two tasks at once, so it is constrained by both tasks; the more constraints there are, the easier it is to control the direction in which the GAN is optimized.
Basic structure
Generator
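We decompose each generator into its individual layers: $g_1(z) = g_1^{(m_1)}\big(g_1^{(m_1-1)}(\cdots g_1^{(2)}(g_1^{(1)}(z)))\big)$ and $g_2(z) = g_2^{(m_2)}\big(g_2^{(m_2-1)}(\cdots g_2^{(2)}(g_2^{(1)}(z)))\big)$, where $m$ indexes the layer. Both generators map the same random variable $z$ into different data spaces, so the shared layers can only be those close to the noise input $z$, i.e. the layers with small $m$ (the first few layers).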
Discriminator
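Likewise, we decompose each discriminator into its individual layers: $f_1(x_1) = f_1^{(n_1)}\big(f_1^{(n_1-1)}(\cdots f_1^{(2)}(f_1^{(1)}(x_1)))\big)$ and $f_2(x_2) = f_2^{(n_2)}\big(f_2^{(n_2-1)}(\cdots f_2^{(2)}(f_2^{(1)}(x_2)))\big)$. Both discriminators map inputs from different data spaces to the same kind of output, a real/fake judgment, so the shared layers can only be those close to the output, i.e. the layers with large $n$ (the last few layers).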
Loss
The objective can be summarized with a simple diagram:
The two large rectangles in the figure above represent the constraints that come from modeling the joint distribution of the two domains.
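Code and experimental results
Reference implementation: https://github.com/WingsofFAN/PyTorch-GAN/blob/master/implementations/cogan/cogan.py (note that it uses a least-squares adversarial loss, `torch.nn.MSELoss`, rather than the binary cross-entropy of the standard GAN objective)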
import argparse
import os
import numpy as np
import math
import scipy
import itertools
import mnistm
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch
# Output directory for the sample grids written during training.
os.makedirs("images", exist_ok=True)

# Command-line hyper-parameters.
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=32, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# beta2 controls Adam's second-moment (squared-gradient) running average.
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image samples")
opt = parser.parse_args()
print(opt)

# Shape of a single image as (channels, height, width).
img_shape = (opt.channels, opt.img_size, opt.img_size)
# Run on the GPU whenever one is available.
cuda = torch.cuda.is_available()
def weights_init_normal(m):
    """Initialise one module's parameters in place; meant for ``Module.apply``.

    * Modules whose class name contains ``"Linear"`` get weights drawn from
      N(0, 0.02); their biases are left untouched.
    * Modules whose class name contains ``"BatchNorm"`` get weights drawn from
      N(1, 0.02) and biases set to 0.
    * Every other module keeps PyTorch's default initialisation.

    A BatchNorm layer created with ``affine=False`` has ``weight``/``bias``
    set to ``None``; such layers are skipped instead of raising
    ``AttributeError`` on ``None.data``.

    Args:
        m: The ``torch.nn.Module`` currently being visited by ``apply``.
    """
    classname = m.__class__.__name__
    if "Linear" in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        # NOTE(review): DCGAN-style init usually also covers Conv layers; here
        # convolutions keep the default init — confirm this is intended.
        if getattr(m, "weight", None) is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if getattr(m, "bias", None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
class CoupledGenerators(nn.Module):
    """Pair of generators that share their early layers (CoGAN generator pair).

    A single noise batch is projected by ``fc`` and upsampled by
    ``shared_conv``; both stages are shared by the two domains. The
    domain-specific heads ``G1`` and ``G2`` then turn the shared feature map
    into one image per domain.

    Args:
        latent_dim: Length of each noise vector. Defaults to ``opt.latent_dim``.
        img_size: Requested output height/width. The network upsamples a
            ``img_size // 4`` feature map by a factor of 4, so the output side
            is ``4 * (img_size // 4)``; use a multiple of 4. Defaults to
            ``opt.img_size``.
        channels: Number of channels in each generated image. Defaults to
            ``opt.channels``.
    """

    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super(CoupledGenerators, self).__init__()
        # Fall back to the command-line options so that existing callers
        # (``CoupledGenerators()``) behave exactly as before.
        if latent_dim is None:
            latent_dim = opt.latent_dim
        if img_size is None:
            img_size = opt.img_size
        if channels is None:
            channels = opt.channels

        self.init_size = img_size // 4
        # One-element Sequential keeps the state_dict keys ("fc.0.*") unchanged.
        self.fc = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size ** 2))
        # Shared trunk: (B, 128, s, s) -> (B, 128, 4s, 4s) with s = init_size.
        self.shared_conv = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            # NOTE(review): BatchNorm2d's second positional argument is ``eps``,
            # so this is eps=0.8 (unusually large). It may have been intended
            # as a momentum value; kept to preserve training behaviour.
            nn.BatchNorm2d(128, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
        )
        # Domain-specific heads: same architecture, independent weights.
        self.G1 = self._make_head(channels)
        self.G2 = self._make_head(channels)

    @staticmethod
    def _make_head(channels):
        """Build one unshared head mapping 128 feature maps to ``channels`` in [-1, 1]."""
        return nn.Sequential(
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, eps=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, channels, 3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, noise):
        """Generate one image per domain from the same noise batch.

        Args:
            noise: Tensor of shape (B, latent_dim).

        Returns:
            Tuple ``(img1, img2)``, each of shape
            (B, channels, 4 * init_size, 4 * init_size) with values in [-1, 1].
        """
        out = self.fc(noise)
        out = out.view(out.shape[0], 128, self.init_size, self.init_size)
        img_emb = self.shared_conv(out)
        return self.G1(img_emb), self.G2(img_emb)
class CoupledDiscriminators(nn.Module):
    """Pair of discriminators sharing one convolutional feature extractor.

    Both inputs pass through the same ``shared_conv`` stack. Each domain then
    has its own final ``nn.Linear`` scoring head (``D1`` / ``D2``) that
    produces one unbounded score per image.

    NOTE(review): here the *first* (convolutional) layers are shared and the
    final layers are domain-specific, which is the reverse of the CoGAN
    paper's scheme (discriminators share their last layers). Confirm this is
    intended.

    Args:
        img_size: Height/width of the input images. Defaults to ``opt.img_size``.
        channels: Number of input image channels. Defaults to ``opt.channels``.
    """

    def __init__(self, img_size=None, channels=None):
        super(CoupledDiscriminators, self).__init__()
        # Fall back to the command-line options so that existing callers
        # (``CoupledDiscriminators()``) behave exactly as before.
        if img_size is None:
            img_size = opt.img_size
        if channels is None:
            channels = opt.channels

        def discriminator_block(in_filters, out_filters, bn=True):
            """Stride-2 conv (H, W -> ceil(H/2), ceil(W/2)) + optional BN + LeakyReLU + Dropout2d."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1)]
            if bn:
                # NOTE(review): BatchNorm2d's second positional argument is
                # ``eps``, so this is eps=0.8; kept to preserve behaviour.
                block.append(nn.BatchNorm2d(out_filters, eps=0.8))
            block.extend([nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)])
            return block

        self.shared_conv = nn.Sequential(
            *discriminator_block(channels, 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        # Height/width after the four stride-2 convs. With kernel 3, stride 2
        # and padding 1 each conv maps n -> ceil(n / 2), so ``img_size // 16``
        # is only correct for multiples of 16 (e.g. 28 -> 14 -> 7 -> 4 -> 2).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.D1 = nn.Linear(128 * ds_size ** 2, 1)
        self.D2 = nn.Linear(128 * ds_size ** 2, 1)

    def _features(self, img):
        """Run the shared conv stack and flatten to shape (B, 128 * ds_size ** 2)."""
        out = self.shared_conv(img)
        return out.view(out.shape[0], -1)

    def forward(self, img1, img2):
        """Score one batch from each domain.

        Args:
            img1: Images for the first domain, shape (B, channels, img_size, img_size).
            img2: Images for the second domain, same shape convention.

        Returns:
            Tuple ``(validity1, validity2)`` of shape (B, 1) each; no output
            activation is applied.
        """
        # Determine validity of the first image batch, then the second
        # (same order as before, so dropout draws the same random masks).
        validity1 = self.D1(self._features(img1))
        validity2 = self.D2(self._features(img2))
        return validity1, validity2
# Adversarial criterion: least-squares (MSE) loss against 0/1 targets.
adversarial_loss = torch.nn.MSELoss()

# Instantiate the coupled generator and discriminator pairs.
coupled_generators = CoupledGenerators()
coupled_discriminators = CoupledDiscriminators()
networks = (coupled_generators, coupled_discriminators)

# Move both networks to the GPU first (when one is available) ...
if cuda:
    for net in networks:
        net.cuda()

# ... then apply the custom weight initialisation to every sub-module.
for net in networks:
    net.apply(weights_init_normal)
# Configure one data loader per domain. Images are resized to opt.img_size and
# Normalize(0.5, 0.5) maps ToTensor's [0, 1] range to [-1, 1], the same range
# as the generators' Tanh output.
# NOTE(review): --n_cpu is parsed but not passed as ``num_workers`` here.

# Domain 1: MNIST (single channel).
mnist_root = "../../data/mnist"
os.makedirs(mnist_root, exist_ok=True)
mnist_transform = transforms.Compose(
    [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
)
dataloader1 = DataLoader(
    datasets.MNIST(mnist_root, train=True, download=True, transform=mnist_transform),
    batch_size=opt.batch_size,
    shuffle=True,
)

# Domain 2: MNIST-M (three channels).
mnistm_root = "../../data/mnistm"
os.makedirs(mnistm_root, exist_ok=True)
mnistm_transform = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
dataloader2 = DataLoader(
    mnistm.MNISTM(mnistm_root, train=True, download=True, transform=mnistm_transform),
    batch_size=opt.batch_size,
    shuffle=True,
)
# Adam optimisers; both network pairs use the same hyper-parameters.
_adam_kwargs = {"lr": opt.lr, "betas": (opt.b1, opt.b2)}
optimizer_G = torch.optim.Adam(coupled_generators.parameters(), **_adam_kwargs)
optimizer_D = torch.optim.Adam(coupled_discriminators.parameters(), **_adam_kwargs)

# Typed-tensor constructor for the active device.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.FloatTensor
# ----------
#  Training
# ----------
# The two loaders are consumed in lockstep via zip(), so an epoch ends as soon
# as the shorter one is exhausted.
n_batches = min(len(dataloader1), len(dataloader2))

for epoch in range(opt.n_epochs):
    for i, ((imgs1, _), (imgs2, _)) in enumerate(zip(dataloader1, dataloader2)):
        # If the datasets differ in length, their final batches can differ in
        # size; trim both to the common size so every loss compares tensors of
        # matching shape.
        batch_size = min(imgs1.size(0), imgs2.size(0))
        imgs1, imgs2 = imgs1[:batch_size], imgs2[:batch_size]

        # Adversarial ground truths (constant targets; no gradient needed).
        valid = Tensor(batch_size, 1).fill_(1.0)
        fake = Tensor(batch_size, 1).fill_(0.0)

        # Configure input: the single-channel MNIST batch is broadcast to three
        # channels to match the three-channel MNIST-M images.
        imgs1 = imgs1.type(Tensor).expand(batch_size, 3, opt.img_size, opt.img_size)
        imgs2 = imgs2.type(Tensor)

        # ------------------
        #  Train Generators
        # ------------------
        optimizer_G.zero_grad()

        # Sample noise as generator input: z ~ N(0, 1), shape (B, latent_dim).
        z = Tensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))

        # Generate one batch per domain from the same noise.
        gen_imgs1, gen_imgs2 = coupled_generators(z)

        # The generators try to make both discriminators score their output as valid.
        validity1, validity2 = coupled_discriminators(gen_imgs1, gen_imgs2)
        g_loss = (adversarial_loss(validity1, valid) + adversarial_loss(validity2, valid)) / 2

        g_loss.backward()
        optimizer_G.step()

        # ----------------------
        #  Train Discriminators
        # ----------------------
        optimizer_D.zero_grad()

        # Determine validity of real and generated images; detach() keeps the
        # discriminator update from backpropagating into the generators.
        validity1_real, validity2_real = coupled_discriminators(imgs1, imgs2)
        validity1_fake, validity2_fake = coupled_discriminators(gen_imgs1.detach(), gen_imgs2.detach())

        # The real images contribute two loss terms (one per domain), and the
        # generated images contribute two more.
        d_loss = (
            adversarial_loss(validity1_real, valid)
            + adversarial_loss(validity1_fake, fake)
            + adversarial_loss(validity2_real, valid)
            + adversarial_loss(validity2_fake, fake)
        ) / 4

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, n_batches, d_loss.item(), g_loss.item())
        )

        batches_done = epoch * n_batches + i
        if batches_done % opt.sample_interval == 0:
            # Save both domains' samples in one grid: domain-1 images first, then domain-2.
            gen_imgs = torch.cat((gen_imgs1.detach(), gen_imgs2.detach()), 0)
            save_image(gen_imgs, "images/%d.png" % batches_done, nrow=8, normalize=True)
Test results on MNIST and MNIST-M:
MNIST | MNIST-M