DeepChem教程15: 用MNISTT数据集训练GAN网络

这个教程我们用MNIST数据集来训练成生对抗网格(GAN)。MNIST28x28像素手写数字图像的大的集合。我们将训练网格来产生新的手写数字图像。

In [ ]:

!curl -Lo conda_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py

import conda_installer

conda_installer.install()

!/root/miniconda/bin/conda info -e

In [ ]:

!pip install --pre deepchem

import deepchem

deepchem.__version__

开始,我们需要导入所有我们需要的库并加载数据集(数据集来自Tensorflow

In [1]:

import deepchem as dc

import tensorflow as tf

from deepchem.models.optimizers import ExponentialDecay

from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Reshape

import matplotlib.pyplot as plot

import matplotlib.gridspec as gridspec

%matplotlib inline

mnist = tf.keras.datasets.mnist.load_data(path='mnist.npz')

images = mnist[0][0].reshape((-1, 28, 28, 1))/255

dataset = dc.data.NumpyDataset(images)

我们来看一下图像是什么样子的。

In [2]:

def plot_digits(im):

  plot.figure(figsize=(3, 3))

  grid = gridspec.GridSpec(4, 4, wspace=0.05, hspace=0.05)

  for i, g in enumerate(grid):

    ax = plot.subplot(g)

    ax.set_xticks([])

    ax.set_yticks([])

    ax.imshow(im[i,:,:,0], cmap='gray')

plot_digits(images)

现在我们来创建自已的GAN。像上一个教程一样,它包含两部分:

1.生成器以随机噪音为输入,产生与训练数据相似的输出。

2.分判器以一些样本作为输入(可能是训练样本也可能是生成器产生的样本),并尽量分判真假。

这次我们使用不同风格的GAN叫做Wasserstein GAN (或简称WGAN)。很多情况下,它们被发现能产生比条件GAN更好的结果。这两者的主要区别在于分判器(本文叫"critic")。不是输出样本为真实训练数据的概率,它尽量学习如何测量训练分布与生成分布的距离。这种测量然后被用作损失函数来训练生成器。

我们使用非常简单的模型。生成器变换输入噪音到有8个通道的7x7图像。随后有两个卷积层,先上取样到14x14,最后到28x28

分判器做的事情大致相同只是反过来。两个卷积屋下取样图像到14x14,然后到7x7。最后一个全链接层产生数字作为输出。上一个教程我们使用sigmoid激活函数,产生01之间的数,可以被解释为概率。因为这是一个WGAN,我们使用softplus激活函数。它产生不平衡的正数可被解释为距离。

In [3]:

class DigitGAN(dc.models.WGAN):

  def get_noise_input_shape(self):

    return (10,)

  def get_data_input_shapes(self):

    return [(28, 28, 1)]

  def create_generator(self):

    return tf.keras.Sequential([

        Dense(7*7*8, activation=tf.nn.relu),

        Reshape((7, 7, 8)),

        Conv2DTranspose(filters=16, kernel_size=5, strides=2, activation=tf.nn.relu, padding='same'),

        Conv2DTranspose(filters=1, kernel_size=5, strides=2, activation=tf.sigmoid, padding='same')

    ])

  def create_discriminator(self):

    return tf.keras.Sequential([

        Conv2D(filters=32, kernel_size=5, strides=2, activation=tf.nn.leaky_relu, padding='same'),

        Conv2D(filters=64, kernel_size=5, strides=2, activation=tf.nn.leaky_relu, padding='same'),

        Dense(1, activation=tf.math.softplus)

    ])

gan = DigitGAN(learning_rate=ExponentialDecay(0.001, 0.9, 5000))

现在来训练它。就像上一个教程,我们写一个生成器来产生数据。这次数据来自数据集,我们用数据迭代表100次。

另一个不同点并不重要。训练传统的GAN时,重要的是保持生成器和分判器在整个训练过程的平衡。任意一个走得过快,另一个就会很难学习。

WGANs不会有这个问题。事实上,分判器越好,它给出的信号越清晰,它就越容易被生成器学习。因此我们指定generator_steps=0.2以至它仅采取一步训练生成器每五步训练分判器。这趋于产生更快的训练和更好的结果。

In [4]:

def iterbatches(epochs):

  for i in range(epochs):

    for batch in dataset.iterbatches(batch_size=gan.batch_size):

      yield {gan.data_inputs[0]: batch[0]}

gan.fit_gan(iterbatches(100), generator_steps=0.2, checkpoint_interval=5000)

Ending global_step 4999: generator average loss 0.340072, discriminator average loss -0.0234236

Ending global_step 9999: generator average loss 0.52308, discriminator average loss -0.00702729

Ending global_step 14999: generator average loss 0.572661, discriminator average loss -0.00635684

Ending global_step 19999: generator average loss 0.560454, discriminator average loss -0.00534357

Ending global_step 24999: generator average loss 0.556055, discriminator average loss -0.00620613

Ending global_step 29999: generator average loss 0.541958, discriminator average loss -0.00734233

Ending global_step 34999: generator average loss 0.540904, discriminator average loss -0.00736641

Ending global_step 39999: generator average loss 0.524298, discriminator average loss -0.00650514

Ending global_step 44999: generator average loss 0.503931, discriminator average loss -0.00563732

Ending global_step 49999: generator average loss 0.528964, discriminator average loss -0.00590612

Ending global_step 54999: generator average loss 0.510892, discriminator average loss -0.00562366

Ending global_step 59999: generator average loss 0.494756, discriminator average loss -0.00533636

TIMING: model fitting took 4197.860 s

Let's generate some data and see how the results look.

In [5]:

Plot_digits(gan.predict_gan_generator(batch_size=16))

不错,许多生成的图像看起来像手写数字。模型越大训练时间越长结果当然更好。

下载全文请到www.data-vision.net,技术联系电话13712566524

猜你喜欢

转载自blog.csdn.net/lishaoan77/article/details/114334632