keras 实现GAN(生成对抗网络)

本文将介绍如何在Keras中以最小的形式实现GAN。具体实现是一个深度卷积GAN,或DCGAN:一个GAN,其中generator和discriminator是深度卷积网络,它利用`Conv2DTranspose`层对generator中的图像上采样。
然后将在CIFAR10的图像上训练GAN,CIFAR10数据集由属于10个类别(每个类别5,000个图像)的50,000个32x32 RGB图像构成。为了节约时间,本文将只使用“frog”类的图像。
原理上,GAN的组成如下所示:
 *`generator`网络将shape`(latent_dim,)`的矢量映射到shape`(32,32,3)`的图像。
*“discriminator”网络将形状(32,32,3)的图像映射到估计图像是真实的概率的二进制分数。
 *`gan`网络将generator和discriminator链接在一起:`gan(x)=discriminator(generator(x))`。因此,这个“gan”网络将潜在空间向量映射到discriminator对由generator解码的这些潜在向量的真实性的评估。
*使用真实和虚假图像以及“真实”/“假”标签来训练鉴别器,因此需要训练任何常规图像分类模型。
 *为了训练generator,我们使用generator权重的梯度来减少“gan”模型的损失。这意味着,在每个step中,将generator的权重移动到使得discriminator更可能被分类为由generator解码的图像“真实”的方向上。即训练generator来欺骗discriminator。

一些技巧
训练和实现GAN实现是非常困难,应该记住一些已知的“技巧”。像深度学习中的大多数事情一样,它更像炼金术而不是科学:这些技巧实际上只是启发式,而不是理论支持的指导。他们得到了对手头现象的某种程度的直观理解的支持,并且他们知道在经验上很好地工作,尽管不一定在每种情况下。
 以下是在自己实现的GAN生成器和鉴别器中利用的一些技巧。
*使用`tanh`作为生成器中的最后一次激活,而不是`sigmoid`,这在其他类型的模型中更常见。
 *使用_normal distribution_(高斯分布)从潜在空间中采样点,而不是均匀分布。
 *随机性很好地诱导稳健性。由于GAN训练导致动态均衡,GAN可能会以各种方式“卡住”。
在训练期间引入随机性有助于防止这种情况。

以两种方式引入随机性:1)在鉴别器中使用dropout,2)在鉴别器的标签上添加一些随机噪声。
 *稀疏渐变可能会阻碍GAN训练。在深度学习中,稀疏性通常是理想的属性,但在GAN中则不然。有两件事可以引起梯度稀疏:1) max pooling操作,2)ReLU激活。建议使用跨步卷积进行下采样,而不是最大池,建议使用`LeakyReLU`层而不是ReLU激活。它类似于ReLU,但它通过允许小的负激活值来放宽稀疏性约束。
 *在生成的图像中,通常会看到由于生成器中像素空间的不均匀覆盖而导致的“棋盘格伪影”。为了解决这个问题,每当在生成器和鉴别器中使用跨步的`Conv2DTranpose`或`Conv2D`时,使用可被步长大小整除的内核大小。

下面示例代码:


# coding: utf-8

# In[6]:


'''
生成器(generator)
首先,创建一个“生成器(generator)”模型,它将一个矢量(从潜在空间 - 在训练期间随机采样)转换为候选图像。
GAN通常出现的许多问题之一是generator卡在生成的图像上,看起来像噪声。一种可能的解决方案是在鉴别器(discriminator)
和生成器(generator)上使用dropout。
'''
import keras
from keras import layers
import numpy as np

latent_dim = 32
height = 32
width = 32
channels = 3

generator_input = keras.Input(shape=(latent_dim,))

# 首先,将输入转换为16x16 128通道的feature map
x = layers.Dense(128 * 16 * 16)(generator_input)
x = layers.LeakyReLU()(x)
x = layers.Reshape((16, 16, 128))(x)

# 然后,添加卷积层
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# 上采样至 32 x 32
x = layers.Conv2DTranspose(256, 4, strides=2, padding='same')(x)
x = layers.LeakyReLU()(x)

# 添加更多的卷积层
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(256, 5, padding='same')(x)
x = layers.LeakyReLU()(x)

# 生成一个 32x32 1-channel 的feature map
x = layers.Conv2D(channels, 7, activation='tanh', padding='same')(x)
generator = keras.models.Model(generator_input, x)
generator.summary()


# In[8]:


'''
discriminator(鉴别器)
创建鉴别器模型,它将候选图像(真实的或合成的)作为输入,并将其分为两类:“生成的图像”或“来自训练集的真实图像”。
'''
discriminator_input = layers.Input(shape=(height, width, channels))
x = layers.Conv2D(128, 3)(discriminator_input)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Conv2D(128, 4, strides=2)(x)
x = layers.LeakyReLU()(x)
x = layers.Flatten()(x)

# 重要的技巧(添加一个dropout层)
x = layers.Dropout(0,4)(x)

# 分类层
x = layers.Dense(1, activation='sigmoid')(x)

discriminator = keras.models.Model(discriminator_input, x)
discriminator.summary()


# In[11]:


# 为了训练稳定,在优化器中使用学习率衰减和梯度限幅(按值)。
discriminator_optimizer = keras.optimizers.RMSprop(lr=8e-4, clipvalue=1.0, decay=1e-8)
discriminator.compile(optimizer=discriminator_optimizer, loss='binary_crossentropy')


# In[16]:


'''
The adversarial network:对抗网络
最后,设置GAN,它链接生成器(generator)和鉴别器(discrimitor)。 这是一种模型,经过训练,
将使生成器(generator)朝着提高其愚弄鉴别器(discrimitor)能力的方向移动。 该模型将潜在的空间点转换为分类决策,
“假的”或“真实的”,并且意味着使用始终是“这些是真实图像”的标签来训练。 所以训练`gan`将以一种方式更新
“发生器”的权重,使得“鉴别器”在查看假图像时更可能预测“真实”。 非常重要的是,将鉴别器设置为在训练
期间被冻结(不可训练):训练“gan”时其权重不会更新。 如果在此过程中可以更新鉴别器权重,那么将训练鉴别
器始终预测“真实”。
'''
# 将鉴别器(discrimitor)权重设置为不可训练(仅适用于`gan`模型)
discriminator.trainable = False

gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)

gan_optimizer = keras.optimizers.RMSprop(lr=4e-4, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')


# In[19]:


'''
  开始训练了。
  每个epoch:
   *在潜在空间中绘制随机点(随机噪声)。
   *使用此随机噪声生成带有“generator”的图像。
   *将生成的图像与实际图像混合。
   *使用这些混合图像训练“鉴别器”,使用相应的目标,“真实”(对于真实图像)或“假”(对于生成的图像)。
   *在潜在空间中绘制新的随机点。
   *使用这些随机向量训练“gan”,目标都是“这些是真实的图像”。 这将更新发生器的权重(仅因为鉴别器在“gan”内被冻结)
   以使它们朝向获得鉴别器以预测所生成图像的“这些是真实图像”,即这训练发生器欺骗鉴别器。
'''
import os
from keras.preprocessing import image

# 导入CIFAR10数据集
(x_train, y_train), (_, _) = keras.datasets.cifar10.load_data()

# 从CIFAR10数据集中选择frog类(class 6)
x_train = x_train[y_train.flatten() == 6]

# 标准化数据
x_train = x_train.reshape(
    (x_train.shape[0],) + (height, width, channels)).astype('float32') / 255.

iterations = 10000
batch_size = 20
save_dir = '.\\gan_image'

start = 0 
# 开始训练迭代
for step in range(iterations):
    # 在潜在空间中抽样随机点
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
    
    # 将随机抽样点解码为假图像
    generated_images = generator.predict(random_latent_vectors)
    
    # 将假图像与真实图像进行比较
    stop = start + batch_size
    real_images = x_train[start: stop]
    combined_images = np.concatenate([generated_images, real_images])
    
    # 组装区别真假图像的标签
    labels = np.concatenate([np.ones((batch_size, 1)),
                            np.zeros((batch_size, 1))])
    # 重要的技巧,在标签上添加随机噪声
    labels += 0.05 * np.random.random(labels.shape)
    
    # 训练鉴别器(discrimitor)
    d_loss = discriminator.train_on_batch(combined_images, labels)
    
    # 在潜在空间中采样随机点
    random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
    
    # 汇集标有“所有真实图像”的标签
    misleading_targets = np.zeros((batch_size, 1))
    
    # 训练生成器(generator)(通过gan模型,鉴别器(discrimitor)权值被冻结)
    a_loss = gan.train_on_batch(random_latent_vectors, misleading_targets)
    
    start += batch_size
    if start > len(x_train) - batch_size:
        start = 0
    if step % 100 == 0:
        # 保存网络权值
        gan.save_weights('gan.h5')

        # 输出metrics
        print('discriminator loss at step %s: %s' % (step, d_loss))
        print('adversarial loss at step %s: %s' % (step, a_loss))

        # 保存生成的图像
        img = image.array_to_img(generated_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'generated_frog' + str(step) + '.png'))

        # 保存真实图像,以便进行比较
        img = image.array_to_img(real_images[0] * 255., scale=False)
        img.save(os.path.join(save_dir, 'real_frog' + str(step) + '.png'))


# In[ ]:


# 绘图
import matplotlib.pyplot as plt

# 在潜在空间中抽样随机点
random_latent_vectors = np.random.normal(size=(10, latent_dim))

# 将随机抽样点解码为假图像
generated_images = generator.predict(random_latent_vectors)

for i in range(generated_images.shape[0]):
    img = image.array_to_img(generated_images[i] * 255., scale=False)
    plt.figure()
    plt.imshow(img)
    
plt.show()

结果:

猜你喜欢

转载自blog.csdn.net/github_39611196/article/details/84198545