GAN网络生成手写体数字图片

Keras真香,以前都是用tensorflow来写神经网络,自从用了keras,发现这个Keras也蛮方便的。
目前感觉keras的优点就是方便搭建基于标准网络组件的神经网络,这里的网络组件包括全连接层,卷积层,池化层等等。对于需要对网络本身做创新的实验,keras可能不是很方便,还是得用tensorflow来搭建。

这篇博客,我想用Keras写一个简单的生成对抗网络。
生成对抗网络的目标是生成手写体数字。

先看看实验的效果:
epoch=1000的时候:
在这里插入图片描述
epoch=10000的时候:数字1已经有点像了
在这里插入图片描述
epoch=60000,数字1就很清晰了 ,而且其他数字也越来越清晰了
在这里插入图片描述
epoch=80000: 生成了5,7 啥的了。
在这里插入图片描述
随着训练的加深,生成的数字会越来越真实了。
代码已经开源,项目地址:

https://github.com/jmhIcoding/GAN_MNIST.git

模型原理

模型原理就不说了,就是使用最基础GAN结构。
模型由一个生成器和一个鉴别器组成。
生成器用于输入噪声,然后生成一个手写体数字图片。
鉴别器用于判断某个输入给它的图片是不是生成器合成的。

生成器的目标是生成让鉴别器判断为非合成的图片。
鉴别器的目标则是以尽量高的正确率分类某种图片是否为合成的。

总的原理就是这些了。
模型的损失函数就是围绕着这两个目标来展开的。

模型编写

生成器

__author__ = 'dk'
#生成器

import sys
import numpy as np

import  keras
from  keras import layers
from keras import models
from  keras import optimizers
from keras import losses

class Generator:
    """Fully-connected GAN generator: maps a latent noise vector to a
    (width, height, channel) image with tanh-scaled pixels in [-1, 1]."""

    def __init__(self,height=28,width=28,channel=1,latent_space_dimension=100):
        '''
        :param height:    height of the generated image (28 for MNIST)
        :param width:     width of the generated image (28 for MNIST)
        :param channel:   number of channels of the generated image (1 for MNIST grayscale)
        :param latent_space_dimension:  dimensionality of the input noise vector
        :return:
        '''
        self.latent_space_dimension = latent_space_dimension
        self.height = height
        self.width = width
        self.channel = channel
        self.generator = self.build_model()
        self.generator.summary()

    def build_model(self,block_starting_size=128,num_blocks=4):
        """Build the stacked-Dense generator.

        Each block after the first doubles the layer width and is followed
        by LeakyReLU + BatchNormalization.  NOTE(review): the first Dense
        layer gets no activation/normalization — kept as-is to preserve the
        published architecture (the printed summary confirms this layout).
        """
        model = models.Sequential(name='generator')
        for i in range(num_blocks):
            if i ==0 :
                # Input block: latent vector -> block_starting_size units.
                model.add(layers.Dense(block_starting_size,input_shape=(self.latent_space_dimension,)))
            else:
                block_size = block_starting_size * (2**i)
                model.add(layers.Dense(block_size))
                model.add(layers.LeakyReLU())
                model.add(layers.BatchNormalization(momentum=0.75))

        # tanh keeps outputs in [-1, 1], matching the (x-128)/128 pixel scaling.
        model.add(layers.Dense(self.height*self.channel*self.width,activation='tanh'))
        model.add(layers.Reshape((self.width,self.height,self.channel)))
        return  model

    def summary(self):
        # BUG FIX: the original referenced the non-existent attribute
        # `self.model`, which raised AttributeError; the built model is
        # stored in `self.generator`.
        self.generator.summary()

    def save_model(self):
        # Persist architecture + weights in HDF5 format.
        self.generator.save("generator.h5")

注意,generator是和整个模型一起训练的,它可以不需要compile模型。

鉴别器

__author__ = 'dk'
#判别器
import sys
import os
import keras
from  keras import layers
from keras import optimizers
from keras import models
from keras import losses
class Discriminator:
    """Binary classifier that predicts whether an input image is real (label 1)
    or synthesized by the generator (label 0)."""

    def __init__(self,height=28,width=28,channel=1):
        '''
        :param height:  height of the input image
        :param width:   width of the input image
        :param channel: number of channels of the input image
        :return:
        '''
        self.height = height
        self.width = width
        self.channel = channel
        # BUG FIX: the original called build_model() twice and threw the
        # first model away; build it once.
        self.discriminator = self.build_model()
        OPTIMIZER = optimizers.Adam()
        # The discriminator is trained stand-alone via train_on_batch,
        # so it needs its own compile step (unlike the generator).
        self.discriminator.compile(optimizer=OPTIMIZER,loss=losses.binary_crossentropy,metrics =['accuracy'])
        self.discriminator.summary()

    def build_model(self):
        """Flatten -> shrinking Dense pyramid -> sigmoid probability of 'real'."""
        model = models.Sequential(name='discriminator')
        model.add(layers.Flatten(input_shape=(self.width,self.height,self.channel)))
        # Dropped the redundant input_shape kwarg here: after Flatten this
        # layer is no longer the first layer, so Keras ignores it anyway.
        model.add(layers.Dense(self.height*self.width*self.channel))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(self.height*self.width*self.channel//2))
        model.add(layers.LeakyReLU(0.2))
        model.add(layers.Dense(1,activation='sigmoid'))
        return model

    def summary(self):
        return self.discriminator.summary()

    def save_model(self):
        # Persist architecture + weights in HDF5 format.
        self.discriminator.save("discriminator.h5")

gan网络

把生成器和鉴别器合并起来

__author__ = 'dk'
#生成对抗网络

import keras
from keras import layers
from  keras import optimizers
from  keras import  losses
from  keras import models

import  sys
import os

from Discriminator import Discriminator
from Generator import Generator
class GAN:
    """Combined model: noise -> generator -> (frozen) discriminator.

    Training this combined model updates only the generator's weights; the
    discriminator is trained separately through its own train_on_batch calls.
    """

    def __init__(self,latent_space_dimension,height,width,channel):
        self.generator  = Generator(height,width,channel,latent_space_dimension)
        self.discriminator = Discriminator(height,width,channel)
        # Freeze the discriminator inside the combined model: the GAN
        # training step must adjust only the generator. The discriminator
        # itself was already compiled as trainable, so its stand-alone
        # train_on_batch calls still update it.
        self.discriminator.discriminator.trainable = False
        self.gan =  self.build_model()
        self.gan.compile(optimizer=optimizers.Adamax(),
                         loss=losses.binary_crossentropy)
        self.gan.summary()

    def build_model(self):
        # Chain the two sub-models: generated image feeds straight into
        # the discriminator's real/fake prediction.
        combined  = models.Sequential(name='gan')
        combined.add(self.generator.generator)
        combined.add(self.discriminator.discriminator)
        return combined

    def summary(self):
        self.gan.summary()

    def save_model(self):
        # Persist the combined model in HDF5 format.
        self.gan.save("gan.h5")

数据准备模块

__author__ = 'dk'
#数据集采集器,主要是对mnist进行简单的封装
from keras.datasets import mnist
import numpy as np
def sample_latent_space(instances_number,latent_space_dimension):
    """Draw `instances_number` latent noise vectors from a standard normal
    distribution, shaped (instances_number, latent_space_dimension)."""
    shape = (instances_number, latent_space_dimension)
    return np.random.normal(loc=0, scale=1, size=shape)

class Dator:
    """Thin wrapper around the MNIST training set that serves mini-batches,
    optionally filtered to a single digit class."""

    def __init__(self,batch_size=None,model_type=1):
        '''
        :param batch_size:  default mini-batch size; None means "use the whole
                            (possibly filtered) training set" as one batch.
        :param model_type:  -1 selects all digits 0-9; a value in 0..9 keeps
                            only that digit (e.g. model_type=2 keeps only "2"s).
        :return:
        '''
        self.model_type = model_type
        # Expects a local mnist.npz archive in keras' standard layout.
        with np.load("mnist.npz", allow_pickle=True) as f:
            X_train, y_train = f['x_train'], f['y_train']
            #X_test, y_test = f['x_test'], f['y_test']
        if model_type != -1:
            # Keep only samples of the requested digit.
            X_train = X_train[np.where(y_train==model_type)[0]]
        # BUG FIX: compare to None with `is`, not `==`; also removed the
        # redundant first assignment of self.batch_size (it was overwritten
        # by both branches below).
        if batch_size is None:
            self.batch_size = X_train.shape[0]
        else:
            self.batch_size = batch_size

        # Scale uint8 pixels [0, 255] to roughly [-1, 1] so real images match
        # the generator's tanh output range.
        self.X_train = (np.float32(X_train)-128)/128.0
        # Add the trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
        self.X_train = np.expand_dims(self.X_train,3)

        self.watch_index = 0
        self.train_size = self.X_train.shape[0]

    def next_batch(self,batch_size = None):
        """Return the next `batch_size` images, wrapping to the start of the
        training set when the cursor reaches the end."""
        if batch_size is None:
            batch_size  =self.batch_size

        # Pad the tail slice with images from the front, then truncate, so a
        # batch that crosses the end of the array still has full size.
        X=np.concatenate([self.X_train[self.watch_index:(self.watch_index+batch_size)], self.X_train[:batch_size]])[:batch_size]
        self.watch_index  = (self.watch_index + batch_size) % self.train_size
        return  X

if __name__ == '__main__':
    # Smoke test: print a 5x4 sample of latent noise.
    print(sample_latent_space(5,4))

训练main脚本:train.py

__author__ = 'dk'
#模型训练代码
from  GAN import GAN
from data_utils import Dator,sample_latent_space
import  numpy as np
from matplotlib import pyplot as plt
import time

# Hyper-parameters for the GAN training run (MNIST: 28x28 grayscale).
epochs = 50000
height = 28
width = 28
channel =1
latent_space_dimension = 100
batch = 128
dator = Dator(batch_size=batch,model_type=-1)  # -1: train on all digits 0-9
gan = GAN(latent_space_dimension,height,width,channel)
image_index = 0
for i in range(epochs):
    real_img = dator.next_batch(batch_size=batch*2)
    real_label = np.ones(shape=(real_img.shape[0],1))       #real samples are labeled 1

    noise = sample_latent_space(real_img.shape[0],latent_space_dimension)
    fake_img = gan.generator.generator.predict(noise)
    fake_label = np.zeros(shape=(fake_img.shape[0],1))      #images produced by the generator are labeled 0

    ### Assemble the mixed real/fake batch for the discriminator
    x_batch = np.concatenate([real_img,fake_img])
    y_batch = np.concatenate([real_label,fake_label])
    # One discriminator training step; [0] extracts the loss from [loss, accuracy]
    discriminator_loss = gan.discriminator.discriminator.train_on_batch(x_batch,y_batch)[0]
    ### Note: this step trains only the discriminator; the generator is untouched.
    ### Assemble the training data for the generator
    noise = sample_latent_space(batch*2,latent_space_dimension)
    noise_labels = np.ones((batch*2,1))
    # The generator's goal is to make its images be classified as label 1 (real);
    # the discriminator is frozen inside gan.gan, so only the generator updates.

    generator_loss = gan.gan.train_on_batch(noise,noise_labels)

    print('Epoch : {0}, [Discriminator Loss:{1} ], [Generator Loss:{2}]'.format(i,discriminator_loss,generator_loss))

    if i!=0 and (i%50)==0:
        print('show time')
        # Every 50 iterations, render 16 generated images to inspect progress
        noise = sample_latent_space(16,latent_space_dimension)
        images = gan.generator.generator.predict(noise)
        plt.figure(figsize=(10,10))
        plt.suptitle('epoch={0}'.format(i),fontsize=16)
        for index in range(images.shape[0]):
            plt.subplot(4,4,index+1)
            image  =images[index,:,:,:]
            image = image.reshape(height,width)
            plt.imshow(image,cmap='gray')
        #plt.tight_layout()
        # File name is a timestamp so successive snapshots never collide.
        plt.savefig("./show_time/{0}.png".format(time.time()))
        image_index += 1
        plt.close()


运行脚本

python3 train.py 

即可。
输出:

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_1 (Dense)              (None, 128)               12928     
_________________________________________________________________
dense_2 (Dense)              (None, 256)               33024     
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
batch_normalization_1 (Batch (None, 256)               1024      
_________________________________________________________________
dense_3 (Dense)              (None, 512)               131584    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dense_4 (Dense)              (None, 1024)              525312    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
batch_normalization_3 (Batch (None, 1024)              4096      
_________________________________________________________________
dense_5 (Dense)              (None, 784)               803600    
_________________________________________________________________
reshape_1 (Reshape)          (None, 28, 28, 1)         0         
=================================================================
Total params: 1,513,616
Trainable params: 1,510,032
Non-trainable params: 3,584
_________________________________________________________________
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_2 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_9 (Dense)              (None, 784)               615440    
_________________________________________________________________
leaky_re_lu_6 (LeakyReLU)    (None, 784)               0         
_________________________________________________________________
dense_10 (Dense)             (None, 392)               307720    
_________________________________________________________________
leaky_re_lu_7 (LeakyReLU)    (None, 392)               0         
_________________________________________________________________
dense_11 (Dense)             (None, 1)                 393       
=================================================================
Total params: 923,553
Trainable params: 923,553
Non-trainable params: 0
_________________________________________________________________
Model: "gan"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
generator (Sequential)       (None, 28, 28, 1)         1513616   
_________________________________________________________________
discriminator (Sequential)   (None, 1)                 923553    
=================================================================
Total params: 2,437,169
Trainable params: 1,510,032
Non-trainable params: 927,137
_________________________________________________________________
····
···
··

Epoch : 117754, [Discriminator Loss:0.22975191473960876 ], [Generator Loss:2.57688570022583]
Epoch : 117755, [Discriminator Loss:0.26782122254371643 ], [Generator Loss:3.1791584491729736]
Epoch : 117756, [Discriminator Loss:0.2609345614910126 ], [Generator Loss:2.960988998413086]
Epoch : 117757, [Discriminator Loss:0.2673880159854889 ], [Generator Loss:2.317220687866211]
Epoch : 117758, [Discriminator Loss:0.24904575943946838 ], [Generator Loss:1.929720401763916]
Epoch : 117759, [Discriminator Loss:0.25158950686454773 ], [Generator Loss:2.954155683517456]
Epoch : 117760, [Discriminator Loss:0.20324105024337769 ], [Generator Loss:3.5244760513305664]
Epoch : 117761, [Discriminator Loss:0.2849388122558594 ], [Generator Loss:3.195873498916626]
Epoch : 117762, [Discriminator Loss:0.19631560146808624 ], [Generator Loss:2.328411340713501]
Epoch : 117763, [Discriminator Loss:0.20523831248283386 ], [Generator Loss:2.402683973312378]
Epoch : 117764, [Discriminator Loss:0.2625979781150818 ], [Generator Loss:3.2176101207733154]
Epoch : 117765, [Discriminator Loss:0.29969191551208496 ], [Generator Loss:2.9656052589416504]
Epoch : 117766, [Discriminator Loss:0.270328551530838 ], [Generator Loss:2.3880398273468018]
Epoch : 117767, [Discriminator Loss:0.26741161942481995 ], [Generator Loss:2.7729406356811523]
Epoch : 117768, [Discriminator Loss:0.28797847032546997 ], [Generator Loss:2.8959264755249023]
Epoch : 117769, [Discriminator Loss:0.30181047320365906 ], [Generator Loss:2.791097402572632]
Epoch : 117770, [Discriminator Loss:0.26939862966537476 ], [Generator Loss:2.3666043281555176]
Epoch : 117771, [Discriminator Loss:0.26297527551651 ], [Generator Loss:2.895970582962036]
Epoch : 117772, [Discriminator Loss:0.21928083896636963 ], [Generator Loss:3.4627976417541504]
Epoch : 117773, [Discriminator Loss:0.3553962707519531 ], [Generator Loss:3.2194197177886963]
Epoch : 117774, [Discriminator Loss:0.32673510909080505 ], [Generator Loss:2.473867893218994]
Epoch : 117775, [Discriminator Loss:0.31245478987693787 ], [Generator Loss:2.999265193939209]
Epoch : 117776, [Discriminator Loss:0.29536381363868713 ], [Generator Loss:3.733344554901123]
Epoch : 117777, [Discriminator Loss:0.2955515682697296 ], [Generator Loss:3.2467658519744873]
Epoch : 117778, [Discriminator Loss:0.3677394986152649 ], [Generator Loss:1.8517814874649048]
Epoch : 117779, [Discriminator Loss:0.31648850440979004 ], [Generator Loss:2.6385254859924316]
Epoch : 117780, [Discriminator Loss:0.31941041350364685 ], [Generator Loss:3.350475311279297]
Epoch : 117781, [Discriminator Loss:0.47521263360977173 ], [Generator Loss:1.9556307792663574]
Epoch : 117782, [Discriminator Loss:0.44070643186569214 ], [Generator Loss:1.9684114456176758]

猜你喜欢

转载自blog.csdn.net/jmh1996/article/details/106414076
今日推荐