Take Life and Death Lightly; If Unconvinced, Just GAN (5) ---- Generating MNIST Handwritten Digits with DCGAN

Building the DCGAN network

#*************************************** Take Life and Death Lightly; If Unconvinced, Just GAN **************************************************************
"""
PROJECT:MNIST_DCGAN
Author:Ephemeroptera
Date:2018-4-25
QQ:605686962
Reference:' improved_wgan_training-master': <https://github.com/igul222/improved_wgan_training>
           'Zardinality/WGAN-tensorflow':<https://github.com/Zardinality/WGAN-tensorflow>
           'NELSONZHAO/zhihu':<https://github.com/NELSONZHAO/zhihu>
"""
"""
Note: in this section we add batch-normalization layers to G/D to accelerate training. Additionally, we apply an
      exponential moving average (EMA) to the generator's weights to obtain better samples from G.
"""
# import dependency
import tensorflow as tf
import numpy as np
import pickle
import visualization
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from threading import Thread
import time
from time import sleep
import cv2

# import MNIST dataset
mnist_dir = r'../mnist_dataset'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(mnist_dir)

#--------------------------------------- define modules ----------------------------------------------------------------

# deconv
def deconv(img,new_size,fmaps,name='deconv'):
    with tf.variable_scope(name):
        img = tf.image.resize_nearest_neighbor(img,new_size,name='upscale')# upscale
        return tf.layers.conv2d(img,fmaps,3,padding='SAME',name='conv2d')

# define Generator
def Generator_DC_28x28(latents,is_train):

    with tf.variable_scope("generator",reuse=(not is_train)):
        # dense0 ,size = (4,4), fmaps = 512
        dense0 = tf.layers.dense(latents,4*4*512,name='dense0')
        dense0 = tf.reshape(dense0,[-1,4,4,512])
        dense0 = tf.layers.batch_normalization(dense0, training=is_train)
        dense0 = tf.nn.leaky_relu(dense0)
        dense0 = tf.layers.dropout(dense0,rate=0.2)
        a = tf.get_variable_scope().name  # (unused) current variable scope name, kept for debugging

        # deconv1 , size = (7,7) , fmaps = 256
        deconv1 = deconv(dense0,(7,7),256,name='deconv1')
        deconv1 = tf.layers.batch_normalization(deconv1, training=is_train)
        deconv1 = tf.nn.leaky_relu(deconv1)
        deconv1 = tf.layers.dropout(deconv1,rate=0.2)

        # deconv2 , size = (14,14) , fmaps = 128
        deconv2 = deconv(deconv1, (14, 14), 128, name='deconv2')
        deconv2 = tf.layers.batch_normalization(deconv2, training=is_train)
        deconv2 = tf.nn.leaky_relu(deconv2)
        deconv2 = tf.layers.dropout(deconv2, rate=0.2)

        # deconv3 , size = (28,28) , fmaps = 64
        deconv3 = deconv(deconv2, (28, 28), 64, name='deconv3')
        deconv3 = tf.layers.batch_normalization(deconv3, training=is_train)
        deconv3 = tf.nn.leaky_relu(deconv3)
        deconv3 = tf.layers.dropout(deconv3, rate=0.2)

        # toimg , size = (28,28) , fmaps = 1
        toimg = tf.layers.conv2d(deconv3,1,3,padding='SAME',bias_initializer=tf.zeros_initializer,
                                 activation=tf.nn.tanh,name='toimg')

    return toimg

# define Discriminator
def Discriminator_DC_28x28(img,reuse = False):

    with tf.variable_scope("discriminator", reuse=reuse):
        # conv0 , size=(14,14) , fmaps =128
        conv0 = tf.layers.conv2d(img,128,3,padding='SAME',activation=tf.nn.leaky_relu,
                                 kernel_initializer=tf.random_normal_initializer(0,1), name='conv0')
        conv0 = tf.layers.average_pooling2d(conv0,2,2,padding='SAME',name='pool0')

        # conv1 , size=(7,7) , fmaps =256
        conv1 = tf.layers.conv2d(conv0, 256, 3, padding='SAME',
                                 kernel_initializer=tf.random_normal_initializer(0, 1), name='conv1')
        conv1 = tf.layers.batch_normalization(conv1,training=True)
        conv1 = tf.nn.leaky_relu(conv1)
        conv1 = tf.layers.average_pooling2d(conv1, 2, 2, padding='SAME', name='pool1')

        # conv2 , size=(5,5) , fmaps =512
        conv2 = tf.layers.conv2d(conv1, 512, 3, padding='VALID',
                                 kernel_initializer=tf.random_normal_initializer(0, 1), name='conv2')
        conv2 = tf.layers.batch_normalization(conv2, training=True)
        conv2 = tf.nn.leaky_relu(conv2)

        # dense3 : flatten conv2 to 5*5*512 and map to a single logit
        dense3 = tf.reshape(conv2,[-1,5*5*512])
        dense3 = tf.layers.dense(dense3,1,name='dense3')
        outputs = tf.nn.sigmoid(dense3)

    return dense3,outputs

# count the total number of parameters in vars
def COUNT_VARS(vars):
    total_para = 0
    for variable in vars:
        # get each shape of vars
        shape = variable.get_shape()
        variable_para = 1
        for dim in shape:
            variable_para *= dim.value
        total_para += variable_para
    return total_para

# display parameter information
def ShowParasList(paras):
    p = open('./trainLog/Paras.txt', 'w')
    p.writelines(['vars_total: %d'%COUNT_VARS(paras),'\n'])
    for variable in paras:
        p.writelines([variable.name, str(variable.get_shape()),'\n'])
        print(variable.name, variable.get_shape())
    p.close()

# build related dirs
def GEN_DIR():
    if not os.path.isdir('ckpt'):
        print('DIR:ckpt NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('ckpt')
    if not os.path.isdir('trainLog'):
        print('DIR:trainLog NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('trainLog')


#------------------------------------------ define graph ---------------------------------------------------------------
# hyper-parameters
latents_dim = 128
smooth = 0.1

# define inputs
latents = tf.placeholder(shape=[None,latents_dim],dtype=tf.float32,name='latents')
input_real = tf.placeholder(shape=[None,28,28,1],dtype=tf.float32,name='input_real')

# outputs from G\D
# from Generator
g_outputs = Generator_DC_28x28(latents,is_train=True)
g_test = Generator_DC_28x28(latents,is_train=False)
# from Discriminator
d_logits_real, d_outputs_real = Discriminator_DC_28x28(input_real,reuse=False)
d_logits_fake, d_outputs_fake = Discriminator_DC_28x28(g_outputs,reuse=True)

# define loss
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                     labels=tf.ones_like(d_logits_real)) * (1 - smooth))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                         labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                    labels=tf.ones_like(d_logits_fake)) * (1 - smooth))

#-------------------------------------------- Paras Display ------------------------------------------------------------
# list trainable variables
train_vars = tf.trainable_variables()
# separate into d/g
d_train_vars = [var for var in train_vars if var.name.startswith("discriminator")]
g_train_vars = [var for var in train_vars if var.name.startswith("generator")]
# add g_train_vars to 'G_RAW'
for var in g_train_vars:
    tf.add_to_collection('G_RAW',var)

# list all variables
all_vars = tf.global_variables()
# get all vars of G
g_all_vars = [var for var in all_vars if var.name.startswith("generator")]
# get μ,σ from BN of G
g_bn_m_v = [var for var in g_all_vars if 'moving_mean' in var.name]
g_bn_m_v += [var for var in g_all_vars if 'moving_variance' in var.name]
# add to 'G_BN_MV'
for var in g_bn_m_v:
    tf.add_to_collection('G_BN_MV',var)


#--------------------------------------------- Gradient Descent -------------------------------------------------------
# training parameters
learn_rate = 2e-4
G_step = tf.Variable(0, trainable=False)
D_step = tf.Variable(0, trainable=False)

# Gradient Descent
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)): # run the BN moving-statistics updates before each train step
    d_train_opt = tf.train.AdamOptimizer(learn_rate,beta1=0.5).minimize(d_loss, var_list=d_train_vars,global_step=D_step)
    g_train_opt = tf.train.AdamOptimizer(learn_rate,beta1=0.5).minimize(g_loss, var_list=g_train_vars,global_step=G_step)

#---------------------------------------- Exponential Moving Average for G ---------------------------------------------
# apply EMA
G_averages = tf.train.ExponentialMovingAverage(0.999, G_step)
gvars_averages_op = G_averages.apply(g_train_vars) # apply ema
# get shadow
g_vars_ema = [G_averages.average(var) for var in g_train_vars] # g_train_vars using ema
# add g_vars_ema to 'G_EMA'
for ema in g_vars_ema:
    tf.add_to_collection('G_EMA',ema)

# run the generator train step first, then update its EMA shadow variables
with tf.control_dependencies([g_train_opt,gvars_averages_op]):
    g_train_opt_ema = tf.no_op(name='g_train_opt_ema')

#---------------------------------------------- iteration --------------------------------------------------------------
# setting
max_iters = 5000
batch_size = 50
critic_n = 1
# for recording
GEN_DIR()
GenLog = []
Losses = []
saver = tf.train.Saver(var_list=g_train_vars+g_vars_ema+g_bn_m_v) # saving raw and ema

# recording training info
def SavingRecords():
    global Losses
    global GenLog
    # saving Losses
    with open('./trainLog/loss_variation.loss', 'wb') as l:
        losses = np.array(Losses)
        pickle.dump(losses, l)
        print('saving Losses successfully!')
    # saving genlog
    with open('./trainLog/GenLog.log', 'wb') as g:
        GenLog = np.array(GenLog)
        pickle.dump(GenLog, g)
        print('saving GenLog successfully!')

# define training
def training():
    # run
    with tf.Session() as sess:
        # init
        init = (tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init)

        # Go
        time_start = time.time()  # go
        for steps in range(max_iters+1):

            # get batch
            data_batch = mnist.train.next_batch(batch_size)[0]
            # format modification
            data_batch = np.reshape(data_batch,[-1,28,28,1])
            # visualization.CV2_BATCH_RANDOM_SHOW(data_batch,1,25,5,5,0)
            data_batch = data_batch * 2 - 1
            data_batch = data_batch.astype(np.float32)
            # get latents
            z = np.random.normal(0, 1, size=[batch_size, latents_dim]).astype(np.float32)

            # training discriminator
            for n in range(critic_n):
                sess.run(d_train_opt, feed_dict={input_real: data_batch, latents: z})

            # training generator
            sess.run(g_train_opt_ema, feed_dict={input_real: data_batch, latents: z})

            # recording training_losses
            train_loss_d = sess.run(d_loss, feed_dict={input_real: data_batch, latents: z})
            train_loss_g = sess.run(g_loss, feed_dict={latents: z})
            info = [steps, train_loss_d, train_loss_g]

            # recording training_products
            gen_samples = sess.run(g_outputs, feed_dict={latents: z})
            visualization.CV2_BATCH_SHOW((gen_samples[0:9] + 1) / 2, 0.5, 3, 3, delay=1)
            print('iters::%d/%d..Discriminator_loss:%.3f..Generator_loss:%.3f..' % (
                                                                        steps, max_iters, train_loss_d, train_loss_g))

            if steps % 5 == 0:
                Losses.append(info)
                GenLog.append(gen_samples)

            if steps % 1000 == 0 and steps > 0:
                saver.save(sess, './ckpt/generator.ckpt', global_step=steps)

            if steps == max_iters:
                # cv2.destroyAllWindows()
                # report total training time, then save the records on a separate thread
                time_over = time.time()
                print('iteration finished! elapsed time: %.2fs' % (time_over - time_start))
                sleep(3)
                thread1 = Thread(target=SavingRecords, args=())
                thread1.start()

            yield info

#------------------------------------------------ ANIMATION ------------------------------------------------------------
# ANIMATION
"""
note: this animation shows the runtime variation of the G and D losses
"""
iters = []
dloss = []
gloss = []
fig = plt.figure('runtime-losses')
ax1 = fig.add_subplot(2,1,1,xlim=(0, max_iters), ylim=(-10, 10))
ax2 = fig.add_subplot(2,1,2,xlim=(0, max_iters), ylim=(-20, 20))
ax1.set_title('discriminator_loss')
ax2.set_title('generator_loss')
line1, = ax1.plot([], [], color='red',lw=1,label='discriminator')
line2, = ax2.plot([], [],color='blue', lw=1,label='generator')
fig.tight_layout()

def init():
    line1.set_data([], [])
    line2.set_data([], [])
    return line1,line2

def update(info):
    iters.append(info[0])
    dloss.append(info[1])
    gloss.append(info[2])
    line1.set_data(iters, dloss)
    line2.set_data(iters, gloss)
    return line1, line2

ani = FuncAnimation(fig, update, frames=training,init_func=init, blit=True,interval=1,repeat=False)
plt.show()

Validating the DCGAN

After 5000 training iterations, the DCGAN model produces the following results.
1. Loss curves
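The curves can be re-plotted from the pickled record written by SavingRecords(). A minimal sketch, assuming training has finished and ./trainLog/loss_variation.loss exists as produced by the script above:

# reload the pickled losses (recorded every 5 iterations) and plot the D/G curves
import pickle
import matplotlib.pyplot as plt

with open('./trainLog/loss_variation.loss', 'rb') as f:
    losses = pickle.load(f)                      # array of rows [step, d_loss, g_loss]

plt.plot(losses[:, 0], losses[:, 1], color='red', lw=1, label='discriminator')
plt.plot(losses[:, 0], losses[:, 2], color='blue', lw=1, label='generator')
plt.xlabel('iteration')
plt.ylabel('loss')
plt.legend()
plt.show()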

2. Generation log (generator samples recorded during training)
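Similarly, GenLog.log stores one batch of generator outputs every 5 iterations, with values in [-1, 1]. A minimal sketch, assuming the file written above, that displays the first nine samples of the last recorded batch:

# reload GenLog and display the last recorded batch of generated digits
import pickle
import matplotlib.pyplot as plt

with open('./trainLog/GenLog.log', 'rb') as f:
    gen_log = pickle.load(f)                     # shape: [n_records, batch_size, 28, 28, 1]

last_batch = (gen_log[-1] + 1) / 2               # map tanh outputs from [-1, 1] to [0, 1]
fig, axes = plt.subplots(3, 3, figsize=(4, 4))
for img, ax in zip(last_batch[:9], axes.flatten()):
    ax.imshow(img.reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.show()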

Generator validation
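As a rough sketch of how one might verify the trained generator (this is my own illustration, not the author's validation code): restore the latest checkpoint into the graph defined above, copy the EMA shadow values into the raw generator weights, and sample from g_test. It assumes it runs in the same process as the training script, so saver, G_averages, g_train_vars, latents and g_test are already defined.

# sketch: restore the checkpoint, swap in the EMA weights, then sample from the generator
def validate_generator(n_samples=9):
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))
        # overwrite the raw generator weights with their EMA shadow values
        sess.run([tf.assign(var, G_averages.average(var)) for var in g_train_vars])
        z = np.random.normal(0, 1, size=[n_samples, latents_dim]).astype(np.float32)
        samples = sess.run(g_test, feed_dict={latents: z})       # tanh outputs in [-1, 1]
    # show the generated digits in a 3x3 grid
    fig, axes = plt.subplots(3, 3, figsize=(4, 4))
    for img, ax in zip((samples + 1) / 2, axes.flatten()):
        ax.imshow(img.reshape(28, 28), cmap='gray')
        ax.axis('off')
    plt.show()

validate_generator()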


Reposted from blog.csdn.net/Ephemeroptera/article/details/88876098