生成式对抗网络(GAN):基于keras的python学习笔记(3-1)

版权声明:本文为博主原创文章,未经博主允许不得转载。https://blog.csdn.net/weixin_44474718/article/details/88829365
GAN(Generative Adversarial Networks):全为全连接层

CGAN(Conditional):条件GAN

DCGAN(Deep Convolution)2015:卷积层代替池化层,去除全连接层
、使用批归一化、使用恰当的激活函数。
WGAN(Wasserstein):在DCGAN的基础上改进损失函数,解决训练不稳定、模式崩溃的问题,并且生成结果更加多样。提供了一个可以衡量GAN训练好坏的指标。

WGAN-GP(gradient penalty):改进了连续性限制的条件,使用梯度惩罚的方式以满足此连续性条件。解决了训练梯度消失梯度爆炸的问题,比标准WGAN拥有更快的收敛速度,并能生成更高质量的样本,提供稳定的GAN训练方式,几乎不需要怎么调参,成功训练多种针对图片生成和语言模型的GAN架构。(用的多)
LSGAN(Least Squares):最小二乘GAN,采用最小二乘损失函数代替GAN的交叉熵,解决了训练不稳定、图像质量差、多样性不足的问题。

BEGAN(Boundary Equilibrium):快速稳定训练,用误差分布差距来评价生成器生成器质量。(适合高清图像)
SGAN(半监督式):同时训练生成器和半监督式分类器。
ACGAN(辅助分类):结合CGAN和SGAN.
infoGAN():使生成模型可以产生有意义且可解释的特征,让编码中的每一维都具有实际意义,训练时间更短。
Pix2Pix:(需要数据集匹配),完成成对的图像转换(灰度图、梯度图、彩色图之间的转换),可以得到比较清晰的结果。
CycleGAN/DiscoGAN/DualGAN:解决非匹配数据集的图像转换。
DRAGAN :则引入博弈论中的无后悔算法,改造其 loss 以解决 mode collapse问题 [9]。前文所述的 EBGAN 则是加入 VAE 的重构误差以解决 mode collapse。
Multi agent diverse GAN (MAD-GAN) :采用多个生成器,一个判别器,以保障样本生成的多样性。

目标函数
f-divergence: GAN LSGAN EBGAN

Integral probality metric (IPM): WGAN,WGAN-GP,Fisher GAN 比上面的好
在许多 GAN 的应用中,会使用额外的 Loss 用于稳定训练或者达到其他的目的。比如在图像翻译,图像修复,超分辨当中,生成器会加入目标图像作为监督信息。EBGAN 则把 GAN 的判别器作为一个能量函数,在判别器中加入重构误差。CGAN 则使用类别标签信息作为监督信息。

自回归模型:pixelRNN与pixelCNN 由于其像素值是一个个生成的,速度会很慢

DCGAN 提出使用 CNN 结构来稳定 GAN 的训练,并使用了以下一些 trick:
Batch Normalization
使用 Transpose convlution 进行上采样
使用 Leaky ReLu 作为激活函数
上面这些 trick 对于稳定 GAN 的训练有许多帮助,自己设计 GAN 网络时也可以酌情使用。

自编码结构
经典的 GAN 结构里面,判别网络通常被当做一种用于区分真实/生成样本的概率模型。而在自编码器结构里面,判别器(使用 AE 作为判别器)通常被当做能量函数(Energy function)。对于离数据流形空间比较近的样本,其能量较小,反之则大。有了这种距离度量方式,自然就可以使用判别器去指导生成器的学习。
AE 的 loss 是一个重构误差。使用 AE 做为判别器时,如果输入真实样本,其重构误差会很小。如果输入生成的样本,其重构误差会很大。因为对于生成的样本,AE 很难学习到一个图像的压缩表示(即生成的样本离数据流行形空间很远)。所以,VAE 的重构误差作为 Pdata 和 Pg 之间的距离度量是合理的。典型的自编码器结构的 GAN 有:BEGAN,EBGAN,MAGAN 等。

GAN 相比于 VAE 可以生成清晰的图像,但是却容易出现 mode collapse 问题。VAE 由于鼓励重构所有样本,所以不会出现 mode collapse 问题。
一个典型结合二者的工作是 VAEGAN,

# wGAN的简单实现

from __future__ import print_function, division
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
import keras.backend as K
import matplotlib.pyplot as plt
import sys
import numpy as np


class WGAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        # Following parameter and optimizer set as recommended in paper
        self.n_critic = 5
        self.clip_value = 0.01
        optimizer = RMSprop(lr=0.00005)

        # Build and compile the critic
        self.critic = self.build_critic()
        self.critic.compile(loss=self.wasserstein_loss,
                            optimizer=optimizer,
                            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generated imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.critic.trainable = False

        # The critic takes generated images as input and determines validity
        valid = self.critic(img)

        # The combined model  (stacked generator and critic)
        self.combined = Model(z, valid)
        self.combined.compile(loss=self.wasserstein_loss,
                              optimizer=optimizer,
                              metrics=['accuracy'])

    def wasserstein_loss(self, y_true, y_pred):
        return K.mean(y_true * y_pred)

    def build_generator(self):

        model = Sequential()

        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=4, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Activation("relu"))
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation("tanh"))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_critic(self):

        model = Sequential()

        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1))

        model.summary()       # 输出模型的各参数情况

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = (X_train.astype(np.float32) - 127.5) / 127.5
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = -np.ones((batch_size, 1))
        fake = np.ones((batch_size, 1))

        for epoch in range(epochs):

            for _ in range(self.n_critic):

                # ---------------------
                #  Train Discriminator
                # ---------------------

                # Select a random batch of images
                idx = np.random.randint(0, X_train.shape[0], batch_size)
                imgs = X_train[idx]

                # Sample noise as generator input
                noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

                # Generate a batch of new images
                gen_imgs = self.generator.predict(noise)

                # Train the critic
                d_loss_real = self.critic.train_on_batch(imgs, valid)
                d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

                # Clip critic weights
                for l in self.critic.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -self.clip_value, self.clip_value) for w in weights]
                    l.set_weights(weights)

            # ---------------------
            #  Train Generator
            # ---------------------

            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print("%d [D loss: %f] [G loss: %f]" % (epoch, 1 - d_loss[0], 1 - g_loss[0]))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images2/mnist_%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    wgan = WGAN()
    wgan.train(epochs=8000, batch_size=32, sample_interval=50)

运行了四个小时

猜你喜欢

转载自blog.csdn.net/weixin_44474718/article/details/88829365