GAN —— 《Generative Adversarial Nets》

《Generative Adversarial Nets》

  • 生成式对抗网络;
  • 作者:lan Goodfellow;
  • 单位:加拿大蒙特利尔大学;
  • 发表会议及时间:NeurlPS(NIPS) 2014;

核心要点

  1. 提出了一个基于对抗的 新生成式模型,由一个生成器和一个判别器组成;
  2. 生成器的目标是学习到样本的数据分布,从而能生成样本欺骗判别器;判别器的目标是判断输入样本时生成/真实的概率;
  3. GAN模型等同于博弈论中的二人零和博弈;
  4. 对于任意的生成器和判别器,都存在一个独特的全局最优解;
  5. 在本文中,生成器和判别器都是由多层感知机实现,整个网络可以用反向传播算法来训练;
  6. 通过实验的定性与定量分析显示,GAN具备很大的潜力;

研究背景

1、零和博弈

  • 一方的收益必然意味着另一方的损失,博弈各方的收益和损失相加总和永远为“零”,双方不存在合作的可能;
  • 在零和博弈中,为了使己方达到最优解,所以把目标设为让对方的最大化收益最小化;

2、使用数据集

  • MNIST:手写数据集,源自NIST;28*28的灰度图,训练集60000张,测试集10000张;

  • TFD:The Toronro face dataset,人脸数据集;

  • CIFAR-10:32*32彩图,10个类别,每类6000张图,训练集50000张,测试集10000张;

3、GAN价值函数

价值函数
m i n G m a x D V ( D , G ) = E x ∼ p d a t a ( x ) [ l o g D ( x ) ] + E z ∼ p z ( z ) [ l o g ( 1 − D ( G ( z ) ) ) ] min_G max_D V(D,G)=E_{x\sim p_{data}(x)}[log D(x)]+E_{z\sim p_z(z)}[log(1-D(G(z)))] minGmaxDV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]

  • d a t a data data:真实数据;
  • D D D:判别器,输出值为[0,1],代表输入来自真实数据的概率;
  • z z z:随机噪声;
  • G G G:生成器,输出为合成数据;

判别器 D D D的目的是最大化价值函数 V V V,对数函数log在底数大于1时为单调递增函数,最大化 V V V就是最大化 D ( x ) D(x) D(x) 1 − D ( G ( z ) ) 1-D(G(z)) 1D(G(z)),对于任意的x,都有 D ( x ) = 1 D(x)=1 D(x)=1,对于任意的 z z z都有 D ( G ( z ) ) = 0 D(G(z))=0 D(G(z))=0

生成器 G G G的目的是针对特定的 D D D,去最小化价值函数 V V V;最小化价值函数 V V V,就是最小化 D ( x ) D(x) D(x) 1 − D ( G ( z ) ) 1-D(G(z)) 1D(G(z));对于任意的 z z z,都有 D ( G ( z ) ) = 1 D(G(z))=1 D(G(z))=1

训练小trick

  • 在开始训练的时候,生成器 G G G的性能较差, D ( G ( z ) ) D(G(z)) D(G(z))接近于0,此时价值函数中的 l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))的梯度值较小,而 l o g ( D ( G ( z ) ) ) log(D(G(z))) log(D(G(z)))的梯度值较大,所以可以把生成器 G G G的目标改为最大化 l o g D ( G ( z ) ) logD(G(z)) logD(G(z)),这样可以在早期学习中提供更强的梯度。

4、训练流程

  • 使用mini-batch梯度下降(带momentum);
  • 训练k次判别器(本论文实验中k=1);
  • 训练1次生成器;

在这里插入图片描述
根据伪代码可以知道,对应两个神经网络模型——生成器 G G G和判别器 D D D,首先会固定生成器 G G G的参数,使用生成器 G G G生成的数据和真实的数据训练判别器 D D D,训练k次判别器 D D D后,固定判别器 D D D的参数,训练生成器 G G G

理想情况下,判别器的最优解为: D G ∗ ( x ) = P d a t a ( x ) P d a t a ( x ) + P g ( x ) D^*_{G}(x)=\frac{P_{data}(x)}{P_{data}(x)+P_g(x)} DG(x)=Pdata(x)+Pg(x)Pdata(x)判别器取得最优解时,生成器的最优解为: P g = P d a t a P_g=P_{data} Pg=Pdata此时价值函数的值为 C ∗ = − l o g ( 4 ) C^*=-log(4) C=log(4)

模型优劣势

缺点:

  • 没有显式表示的 P g ( x ) P_g(x) Pg(x)
  • 必须同步训练G和D,可能会发生模式崩溃;

优点:

  • 不使用马尔科夫链,在学习过程中不需要推理;
  • 可以将多种函数合并到模型中;
  • 可以表示非常尖锐、甚至退化的分布;
  • 不是直接使用数据来计算loss更新生成器,而是使用判别器的梯度,所以数据不会直接复制到生成器的参数中;

Pytorch代码

# 代码来源:https://github.com/eriklindernoren/PyTorch-GAN
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs("images", exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")  # 迭代次数
parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")  # 批量大小
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")  # adam的学习率
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")  # 动量法
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")  # 生成器输入维度
parser.add_argument("--img_size", type=int, default=28, help="size of each image dimension")  # 照片的尺寸
parser.add_argument("--channels", type=int, default=1, help="number of image channels")  # 通道数,1表示灰度图
parser.add_argument("--sample_interval", type=int, default=400, help="interval betwen image samples")  # 采样照片频率
opt = parser.parse_args()
print(opt)

img_shape = (opt.channels, opt.img_size, opt.img_size)

cuda = True if torch.cuda.is_available() else False


class Generator(nn.Module):  # 生成器
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img


class Discriminator(nn.Module):  # 判别器
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)

        return validity


# Loss function
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../../data/mnist",
        train=True,  # 训练模式
        download=True,  # 如果MNIST没有下载则直接下载
        transform=transforms.Compose(
            [transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
        ),  # 照片处理方式
    ),  # 数据集
    batch_size=opt.batch_size,  # 训练数据批量大小
    shuffle=True,  # 是否打乱
)

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  # 生成器的优化器
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))  # 判别器的优化器

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# ----------
#  Training
# ----------

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)  # 真实数据的label
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)  # 生成数据的label

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))  # 真实照片

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()  #

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))  # 生成随机分布的数据

        # Generate a batch of images
        gen_imgs = generator(z)  # 生成器生成伪照片

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)  # 生成器的目的是骗过判别器,所以希望生成器生成的照片被预测为1

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples
        real_loss = adversarial_loss(discriminator(real_imgs), valid)  # 判别器希望真实的照片预测为1
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)  # 判别器希望伪造的照片预测为0
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

            os.makedirs("model", exist_ok=True)
            torch.save(generator, 'model/generator.pkl')
            torch.save(discriminator, 'model/discriminator.pkl')

猜你喜欢

转载自blog.csdn.net/qq_37388085/article/details/115401477