Object-Oriented GAN Code Based on Keras

Compared with TensorFlow, Keras code is more concise, which reduces development cost. Before starting, it is recommended that you first get familiar with Keras. Here we still use the simplest handwritten digit dataset (MNIST) as the example; you could also use CIFAR-10, CelebA, or another dataset (a rough CIFAR-10 sketch is given after the full listing).


from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import os
import sys

import numpy as np
from tensorflow.examples.tutorials.mnist import input_data


class GAN():
    def __init__(self):
        # Size of the handwritten digit images
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        # Length of the GAN's noise vector (latent dimension)
        self.latent_dim = 100
        # Optimizer
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the stacked generator + discriminator, z ---> img ---> validity, so that a
        # discrimination result can be obtained directly from noise, instead of first generating
        # an image and then calling the discriminator separately every time
        self.generator = self.build_generator()
        # z is the generator's input: a noise vector of length 100
        z = Input(shape=(self.latent_dim,))
        # Feed z into the generator to obtain its output (a feature of the Keras functional API)
        img = self.generator(z)
        # For the combined model we will only train the generator
        self.discriminator.trainable = False
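        # Note: self.discriminator was already compiled above, so calling train_on_batch on it
        # still updates its weights; trainable = False only takes effect for models compiled
        # afterwards, i.e. for self.combined below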
        # The discriminator's verdict on the generated image
        validity = self.discriminator(img)
        self.combined = Model(z, validity)
        # Compile; this can be viewed as binary classification: real images are labeled 1, generated images 0
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    # Generator builder
    def build_generator(self):
        # Build the generator with a Sequential model
        model = Sequential()
        # Dense layer with 256 units
        model.add(Dense(256, input_dim=self.latent_dim))
        # ReLU activation
        model.add(Activation("relu"))
        # Dropout for regularization
        model.add(Dropout(0.4))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(Activation("relu"))
        model.add(Dropout(0.4))
        # BatchNormalization to stabilize training of the randomly initialized layers
        model.add(BatchNormalization(momentum=0.8))
        # Dense layer with 1024 units, followed by ReLU
        model.add(Dense(1024))
        model.add(Activation("relu"))
        model.add(Dropout(0.4))
        model.add(BatchNormalization(momentum=0.8))
        # Equivalent to an output layer with 28x28x1 units
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        # Reshape the output into a 28x28x1 image
        model.add(Reshape(self.img_shape))
        # Print the model summary
        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)
        # Return a Model: given a noise vector, it produces a 28x28x1 image
        return Model(noise, img)

    # Discriminator builder
    def build_discriminator(self):
        # Sequential model
        model = Sequential()
        # Flatten the input image into a 1-D vector
        model.add(Flatten(input_shape=self.img_shape))
        # Dense layer with 512 units
        model.add(Dense(512))
        model.add(Activation("relu"))
        model.add(Dense(256))
        model.add(Activation("relu"))
        model.add(Dropout(0.4))
        # Final layer uses a sigmoid activation
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)
        # The returned Model acts like a function: given a 28x28x1 image, it outputs a probability
        return Model(img, validity)

    # Training loop
    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset. A big pitfall here: MNIST images read via TensorFlow's reader come in
        # with pixel values in [0, 1], while Keras's loader gives 0-255, so when reading the data
        # with TensorFlow the pixel values have to be multiplied by 255
        X_train = input_data.read_data_sets("MNIST_DATA", one_hot=True).train.images
        X_train = np.reshape(X_train, (-1, 28, 28, 1)) * 255
        X_train = X_train / 127.5 - 1.
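
        # Alternative sketch (not part of the original pipeline): keras.datasets.mnist is already
        # imported above and returns pixel values in the range 0-255 directly, so the same
        # preprocessing could also be written without the TensorFlow reader:
        #   (X_train, _), (_, _) = mnist.load_data()      # uint8, shape (60000, 28, 28)
        #   X_train = X_train / 127.5 - 1.                 # scale to [-1, 1]
        #   X_train = np.expand_dims(X_train, axis=3)      # add channel dim -> (60000, 28, 28, 1)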


        # Labels for real images are 1
        valid = np.ones((batch_size, 1))
        # Labels for generated images are 0
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # Randomly select batch_size images from the training set
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample batch_size noise vectors
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Feed the noise into the generator to produce fake images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator on real and fake batches; it tries to tell real images from generated ones
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            # Train the generator through self.combined, which maps a noise vector directly to the
            # discriminator's prediction (the two steps could also be run separately); the generator
            # tries to push its images toward the "real" label 1, the adversarial counterpart of the
            # discriminator update above
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    # Plot and save a grid of generated sample images
    def sample_images(self, epoch):
        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        if not os.path.exists("images"):  # make sure the output directory exists
            os.makedirs("images")
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    gan = GAN()
    gan.train(epochs=20000, batch_size=32, sample_interval=100)
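
As mentioned at the beginning, the same class can be adapted to other datasets such as CIFAR-10. The sketch below only shows the data-loading side, assuming keras.datasets.cifar10 is used; in GAN.__init__ you would also set img_rows = 32, img_cols = 32, channels = 3, and for color images a convolutional (DCGAN-style) generator usually works better than the fully connected one above.


from keras.datasets import cifar10

# Load CIFAR-10; X_train is a uint8 array of shape (50000, 32, 32, 3)
(X_train, _), (_, _) = cifar10.load_data()
# Scale pixel values to [-1, 1] to match the tanh output of the generator
X_train = X_train / 127.5 - 1.
print(X_train.shape)  # (50000, 32, 32, 3)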

Reposted from blog.csdn.net/qq_41559533/article/details/83784216