Tensorflow2.0实现对抗生成网络(GAN)

在这篇文章中,我们使用Tensorflow2.0来实现GAN,使用的数据集是手写数字数据集。

引入需要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
%matplotlib inline

导入数据,归一化数据

(train_images, train_labels), (_, _) = keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images-127.5)/127.5

BATCH_SIZE = 256
BUFFER_SIZE = 60000

datasets = tf.data.Dataset.from_tensor_slices(train_images)
datasets = datasets.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

建立生成器

def generator_model():  # 用100个随机数(噪音)生成手写数据集
    model = keras.Sequential()
    model.add(layers.Dense(256, input_shape=(100,), use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(28*28*1, use_bias=False, activation='tanh'))
    model.add(layers.BatchNormalization())
    
    model.add(layers.Reshape((28, 28, 1)))
    
    return model

建立判别器

def discriminator_model():  # 识别输入的图片
    model = keras.Sequential()
    model.add(layers.Flatten())
    
    model.add(layers.Dense(512, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(256, use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    
    model.add(layers.Dense(1))
    
    return model

分别定义判别器和生成器的损失函数

对于判别器来说,我们需要将导入的原始图片识别为真(1),将生成器胜场的图像识别为假(0)。
对于生成器来说,我们需要使得生成的图片无限接近于真实图片。

cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_out, fake_out):
    real_loss = cross_entropy(tf.ones_like(real_out), real_out)
    fake_loss = cross_entropy(tf.zeros_like(fake_out), fake_out)
    return real_loss + fake_loss

def generator_loss(fake_out):
    return cross_entropy(tf.ones_like(fake_out), fake_out)

在以上代码中,real_out是指向判别器输入原始图像得到的结果;fake_out是指向判别器输入生成图像得到的结果。
所以对于判别器的损失函数来说,real_out应该无限接近于1;fake_out应该无限接近于0。即我们想训练出的判别器应该对图片有很高的识别能力。
但对于生成器的损失函数来说,fake_out应该无限接近于1,也就是令判别器很难分辨出生成的图片。
【注】keras.losses.BinaryCrossentropy(from_logits=True)的用法可以参考:tensorflow2.0中损失函数tf.keras.losses.BinaryCrossentropy()的用法

分别定义生成器和判别器的优化函数

generator_opt = keras.optimizers.Adam(1e-4)
discriminator_opt = keras.optimizers.Adam(1e-4)

实例化生成器和判别器

generator = generator_model()
discriminator = discriminator_model()

定义训练过程

noise_dim = 100  # 即用100个随机数生成图片

def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        real_out = discriminator(images, training=True)
    
        gen_image = generator(noise, training=True)
        fake_out = discriminator(gen_image, training=True)
        
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out, fake_out)
    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))

gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)表示计算gen_loss对于generator的所有变量的梯度。
generator_opt.apply_gradients(zip(gradient_gen, generator.trainable_variables))表示根据gradient_gen来优化generator的变量。
【注】梯度带及梯度更新的用法参考:Tensorflow中的梯度带(GradientTape)以及梯度更新

定义绘图函数

def generate_plot_image(gen_model, test_noise):
    pre_images = gen_model(test_noise, training=False)
    fig = plt.figure(figsize=(4, 4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap='gray')
        plt.axis('off')
    plt.show()

plt.imshow((pre_images[i, :, :, 0] + 1)/2, cmap=‘gray’)
这里是因为我们使用tanh激活函数之后会将结果限制在-1到1之间,而我们需要将其转化到0到1之间。

定义训练函数

EPOCHS = 100  # 训练100次
num_exp_to_generate = 16  # 生成16张图片
seed = tf.random.normal([num_exp_to_generate, noise_dim])  # 16组随机数组,每组含100个随机数,用来生成16张图片。

def train(dataset, epochs):
    for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
            print('.', end='')
        generate_plot_image(generator, seed)
train(datasets, EPOCHS)
发布了116 篇原创文章 · 获赞 13 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/qq_36758914/article/details/104613596
今日推荐