【GAN】一、利用keras实现DCGAN生成手写数字图像

概要

目前仍然在在广州的实习公司继续实习,为了更好的完成任务,以及未来的开题,现在必须仔细学习GAN。之前将GAN和DCGAN两篇论文仔细阅读完了,之后为了检验学习成果写下了这份DCGAN生成手写数字的代码。

虽然是GAN系列的第一篇文章,本想着先从GAN最初论文说起,但是由于很久没更新了博客或者公众号了,想赶紧更新一篇回馈粉丝。巧合的是,GAN的结果展示多为各种图片,转念一想利用讲解代码和展示结果方式来引导GAN系列的开始也不为一个合适的选择。下面开始介绍利用DCGAN生成书写数字。

一、GAN简介

虽然GAN系列第一篇上来就讲解代码,着实让很多小白们难易难以接受。因此我也首先简单介绍一下GAN的原理。

GAN(Generative Adversarial Network)全名叫做对抗生成网络或者生成对抗网络。GAN这一概念是由Ian Goodfellow于2014年提出,并迅速成为了非常火热的研究话题。目前,GAN的变种更是有上千种,2019年计算机界的诺贝尔奖“图灵奖”得主,深度学习先驱之一的Yann LeCun也曾说:“GAN及其变种是数十年来机器学习领域最有趣的想法。”

GAN的主要思想是零和博弈,GAN有两部分组成,一个生成器和一个判别器。生成器主要用于生成图像,判别器用于判别图像是否是“假的”,即图像是否由生成器的概率。GAN的训练可以看成式生成器与判别器之间相互对抗的过程。那么最理想的结果是生成器生成的图像在判别器的预测结果为0.5,即分不清图像是真实图像还是生成器生成的图像。

在原始GAN中,判别器与生成器都是原始的多层感知机即BP神经网络,在DCGAN模型中,BP神经网络都被卷积神经网络所替换。生成器主要是利用一系列反卷积操作将一维噪声向量转化成图像,判别器则是正常的卷积神经网络,将图像进行一系列提取特征之后在判断该图像来自生成器的概率。


二、DCGAN源代码

接下来,我们来介绍利用DCGAN生成手写数字图像。本篇文章的代码全部使用keras进行编写,后端使用的是tensorflow1.14。该项目的源代码网址请移步:DCGAN-mnist

首先给出DCGAN的类代码,这份代码主要由初始化函数、生成器搭建函数、判别器搭建函数,DCGAN的训练函数和保存DCGAN生成的图片的函数5部分构成。代码如下所示:

# -*- coding: utf-8 -*-
# @Time    : 2019/9/15 9:26
# @Author  : DaiPuWei
# @Email   : [email protected]
# @Blog    : https://daipuweiai.blog.csdn.net/
# @File    : DCGAN.py
# @Software: PyCharm

import os
import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import truncnorm
from keras import Model
from keras import Input
from keras import Sequential
from keras.layers import Conv2D
from keras.layers import BatchNormalization
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Dropout
from keras.layers import Reshape
from keras.layers import Dense
from keras.layers import Flatten
from keras.optimizers import Adam
from keras.utils.generic_utils import Progbar

class DCGAN(object):

    def __init__(self,config,discriminator_weight_path = None,dcgan_weight_path=None):
        """
        这是DCGAN的初始化函数
        :param config: 网络模型参数配置类
        :param discriminator_weight_path: 网络模型参数配置类
        :param dcgan_weight_path: 网络模型参数配置类
        """
        # 初始化网络相关超参数类
        self.config = config

        # 构建生成器与判别器
        self.generotor_model = self.build_generator_model()
        self.discriminator_model = self.build_discriminator_model()

        # 构建DCGAN的优化器,并编译判别器
        self.optimizier = Adam(lr=self.config.init_learning_rate,
                               beta_1=self.config.beta1,
                               decay=1e-8)
        if discriminator_weight_path is not None:
            self.discriminator_model.load_weights(discriminator_weight_path,by_name=True)
        self.discriminator_model.compile(loss='binary_crossentropy',
                                         optimizer=self.optimizier)

        # 构建DCGAN模型并进行编译
        dcgan_input = Input(shape=self.config.generator_input_dim)
        dcgan_output = self.discriminator_model(self.generotor_model(dcgan_input))
        self.discriminator_model.trainable = False

        self.dcgan = Model(dcgan_input,dcgan_output)
        if dcgan_weight_path is not None:
            self.dcgan.load_weights(dcgan_weight_path,by_name=True)
        self.dcgan.compile(optimizer=self.optimizier, loss='binary_crossentropy')

    def build_generator_model(self):
        """
        这是构建生成器网络的函数
        :return:返回生成器模型generotor_model
        """
        noise = Input(shape=self.config.generator_input_dim, name="generator_input")

        x = Dense(256*7*7,input_shape=self.config.generator_input_dim,name="dense1")(noise)
        x = BatchNormalization(momentum=self.config.BatchNormalization_Momentum,name="bn1")(x)
        x = Activation('relu',name="relu1")(x)
        x = Reshape((7,7,256),name="reshape")(x)

        x = Conv2DTranspose(128,kernel_size=3,strides=2,padding='same',name="deconv1")(x)
        x = BatchNormalization(momentum=self.config.BatchNormalization_Momentum,name="bn2")(x)
        x = Activation('relu',name="relu2")(x)

        x = Conv2DTranspose(64, kernel_size=3, strides=2, padding='same',name="deconv2")(x)
        x = BatchNormalization(momentum=self.config.BatchNormalization_Momentum,name="bn3")(x)
        x = Activation('relu',name="relu3")(x)

        x = Conv2DTranspose(32, kernel_size=3,padding='same',name="deconv3")(x)
        x = BatchNormalization(momentum=self.config.BatchNormalization_Momentum,name="bn4")(x)
        x = Activation('relu',name="relu4")(x)

        x = Conv2DTranspose(self.config.discriminator_input_dim[2], kernel_size=3,padding='same',name="deconv4")(x)
        x = Activation('tanh',name="generator_output")(x)

        model = Model(noise,x)
        model.summary()

        return model

    def build_discriminator_model(self):
        """
        这是构造判别器模型的函数
        :return: 返回判别器模型discriminator_model
        """
        image = Input(shape=self.config.discriminator_input_dim, name="discriminator_input")

        x = Conv2D(64,kernel_size=3,strides=2,padding='same',name="conv1")(image)
        x = LeakyReLU(self.config.LeakyReLU_alpha,name="leakyrelu1")(x)
        x = Dropout(self.config.dropout_prob,name="dropout1")(x)

        x = Conv2D(128,kernel_size=3,strides=2,padding='same',name="conv2")(x)
        x = LeakyReLU(self.config.LeakyReLU_alpha,name="leakyrelu2")(x)
        x = Dropout(self.config.dropout_prob,name="dropout2")(x)

        x = Conv2D(256,kernel_size=3,strides=2,padding='same',name="conv3")(x)
        x = LeakyReLU(self.config.LeakyReLU_alpha,name="leakyrelu3")(x)
        x = Dropout(self.config.dropout_prob,name="dropout3")(x)

        x = Conv2D(512,kernel_size=3,strides=2,padding='same',name="conv4")(x)
        x = LeakyReLU(self.config.LeakyReLU_alpha,name="leakyrelu4")(x)
        x = Dropout(self.config.dropout_prob,name="dropout4")(x)

        x = Flatten(name="flatten1")(x)
        x = Dense(1,name="dense")(x)
        x = Activation('sigmoid',name="discriminator_output")(x)

        model = Model(image,x)
        model.summary()

        return model

    def train(self,train_datagen,epoch,k,batch_size=256):
        """
        这是DCGAN的训练函数
        :param train_generator:训练数据生成器
        :param epoch:训练周期
        :param batch_size:小批量样本规模
        :param k:训练判别器次数
        :return:
        """
        half_batch = int(batch_size/2)
        length = train_datagen.get_batch_num()
        for ep in np.arange(1,epoch+1):
            dcgan_losses = []
            d_losses = []
            probar = Progbar(length)
            print("Epoch {}/{}".format(ep,epoch))
            iter = 0
            while True:
                # 数据集遍历完成停止循环
                if train_datagen.get_epoch() != ep:
                    break

                iter +=1

                d_loss = []
                for i in np.arange(k):
                    # 获取真实图片及其标签
                    batch_real_images = train_datagen.next_batch()
                    batch_real_images_labels = truncnorm.rvs(0.7, 1.2, size=(half_batch, 1))
                    # 生成一个batch_size的噪声用于生成图片,并制造标签
                    batch_noise = truncnorm.rvs(-1,1,size = (half_batch , self.config.generator_input_dim[0]))
                    batch_gen_images = self.generotor_model.predict(batch_noise)
                    batch_gen_images_labels = truncnorm.rvs(0.0, 0.3, size=(half_batch, 1))

                    # 合并真图与假图及其对应的标签
                    #print(np.shape(batch_gen_images))
                    #print(np.shape(batch_real_images))
                    batch_images = np.concatenate([batch_gen_images, batch_real_images],axis=0)
                    batch_images_labels = np.concatenate((batch_gen_images_labels,batch_real_images_labels))
                    # 训练判别器
                    _d_loss = self.discriminator_model.train_on_batch(batch_images,batch_images_labels)
                    d_loss.append(_d_loss)
                d_loss = np.average(d_loss)

                # 生成一个batch_size的噪声来训练生成器
                batch_noise = truncnorm.rvs(-1,1,size=(half_batch ,self.config.generator_input_dim[0]))
                batch_noise_label = truncnorm.rvs(0.7,1.2,size=(half_batch ,1))
                dcgan_loss = self.dcgan.train_on_batch(batch_noise,batch_noise_label)

                dcgan_losses.append(dcgan_loss)
                d_losses.append(d_loss)

                # 更新进度条
                probar.update(iter,[("dcgan_loss",np.average(dcgan_losses[:iter])),
                                    ("discriminator_loss",np.average(d_losses[:iter]))])

            if int(ep % self.config.save_interval) == 0:
                dcgan_model = "Epoch%ddcgan_loss%.5fdiscriminator_loss%.5f.h5" \
                              % (ep,np.average(dcgan_losses), np.average(d_losses))
                discriminator_model = "Epoch%ddiscriminator_loss%.5f.h5" \
                              % (ep,np.average(d_losses))
                #self.dcgan.save(os.path.join(self.config.save_weight_dir,'dcgan.h5'))
                self.dcgan.save(os.path.join(self.config.save_weight_dir,dcgan_model))
                self.discriminator_model.save(os.path.join(self.config.save_weight_dir, discriminator_model))
                self.save_image(epoch)

    def save_image(self,epoch):
        """
        这是保存生成图片的函数
        :param epoch:周期数,用于图片命名需要
        :return:
        """
        rows, cols = 5, 5
        images = self.generator_batch_images(rows*cols)         # 生成批量数据

        fig, axs = plt.subplots(rows, cols)
        cnt = 0
        for i in range(rows):
            for j in range(cols):
                axs[i, j].imshow(images[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig(os.path.join(self.config.result_path,"mnist-{0:0>5}.png".format(epoch)), dpi=600)
        plt.close()

    def generator_batch_images(self,batch_size):
        """
        这是生成批量规模图像的函数
        :param batch_size: 批量规模
        :return:
        """
        # 生成一个batch_size的噪声用于生成图片
        batch_noise = truncnorm.rvs(-1, 1, size=(batch_size, self.config.generator_input_dim[0]))
        batch_gen_images = self.generotor_model.predict(batch_noise)
        return batch_gen_images

    def generator_image(self):
        """
        这是生成批量规模图像的函数
        :param batch_size: 批量规模
        :return:
        """
        image = self.generator_batch_images(1)[0]
        return image

三 利用DCGAN生成手写数字

接下来我们给出,DCGAN训练与生成手写数字的程序,如下所示。在这份代码中,我们首先构造了一个属于mnist数据集的参数配置类MnistConfig,该类继承自基本参数配置类Config。Config的代结构请详见github链接:DCGAN-mnist,在此我们不在具体给给出。

# -*- coding: utf-8 -*-
# @Time    : 2019/9/15 21:58
# @Author  : DaiPuWei
# @Email   : [email protected]
# @Blog    : https://daipuweiai.blog.csdn.net/
# @File    : train.py
# @Software: PyCharm

import os
import datetime
from Config.Config import Config
from DCGAN.DCGAN import DCGAN
from MnistGenerator.MnistGenerator import MnistGenerator

class MnistConfig(Config):

    def __init__(self):
        #super(Config, self).__init__()
        Config.__init__(self)
        time = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        self.save_weight_dir = os.path.join(Config.get_save_weight_dir(self),time)
        if not os.path.exists(self.save_weight_dir):
            os.mkdir(self.save_weight_dir)
        self.result_path = os.path.join(Config.get_result_path(self),time)
        if not os.path.exists(self.result_path):
            os.mkdir(self.result_path)
        self.batch_size = 256

        print("模型保存在:{}".format(self.save_weight_dir))
        print("训练结果保存在:{}".format(self.result_path))


def run_main():
    """
       这是主函数
    """

    # 训练模型
    cfg =  MnistConfig()
    dcgan = DCGAN(cfg)
    train_datagen = MnistGenerator(int(cfg.batch_size/2))
    dcgan.train(train_datagen,1000,20,cfg.batch_size)      # 训练模型

if __name__ == '__main__':
    run_main()
 

接下来我们给出训练过程中的实验结果。50次迭代:
在这里插入图片描述

第100次迭代:
在这里插入图片描述

第500次迭代:
在这里插入图片描述
第1000次迭代:
在这里插入图片描述

第5000次迭代:
在这里插入图片描述

第10000次迭代:
在这里插入图片描述


四、DCGAN训练小技巧

CNN的训练过程主要就是根据损失函数利用梯度下降及其改进算法进行训练,更新网络参数。但是不同于CNN的训练,GAN的训练是一个动态的过程,GAN的目标是寻求判别器与生成器之间的动态平衡。因此我们不能只靠梯度下降算法进行训练模型。

在DCGAN训练过程中有如下几点小技巧可以直接采纳:

  1. 能用Adam优化器的情况下尽量使用Adam优化器,不行的话使用RMSprop优化器。
  2. DCGAN是无监督学习算法,不需要网络学习标签,但是在keras中为了区分真图与生成器生成的假图,我们必须用标签进行区分,原则上假图应该全部赋值为0,真图赋值为1,但是为了使得模型更加鲁棒,通常使用0-0.3的随随机数代表假图,0.7-1.2的随机数代表真图,随机数的生成都是用截断正态分布,具体实现需导入:from scipy.stats import truncnorm
  3. DCGAN训练过程中,通常是训练判别器,然后训练整个DCGAN。并且DCGAN训练中,判别器不能更新参数,因此必须冻结所有层。为了加快网络训练,通常是训练多次判别器,然后训练一次生成器,这样能提高网络训练速率。

后记

至此,GAN系列的第一篇到此完全结束。在这一篇文章中,我们领略了DCGAN的强大。接下来我们原始GAN开始进行讲解GAN发展。敬请期待GAN系列第二篇:GAN论文详解。

猜你喜欢

转载自blog.csdn.net/qq_30091945/article/details/101036655