Take Life and Death Lightly, and GAN If You're Not Convinced (Part 4) ---- Generating MNIST Handwritten Digits with a Fully-Connected GAN

Building the fully-connected GAN network

#*************************************** Take Life and Death Lightly, and GAN If You're Not Convinced **************************************************************
"""
PROJECT:MNIST_GAN_MLP
Author:Ephemeroptera
Date:2018-4-24
QQ:605686962
Reference:' improved_wgan_training-master': <https://github.com/igul222/improved_wgan_training>
           'Zardinality/WGAN-tensorflow':<https://github.com/Zardinality/WGAN-tensorflow>
           'NELSONZHAO/zhihu':<https://github.com/NELSONZHAO/zhihu>
"""

# import dependency
import tensorflow as tf
import numpy as np
import pickle
import visualization  # local helper module for displaying image batches (provides CV2_BATCH_SHOW)
import os
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from threading import Thread
from time import sleep
import time
import cv2

# import MNIST dataset
mnist_dir = r'../mnist_dataset'
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(mnist_dir)

#------------------------------------------------ model-related definitions -------------------------------------------------

# define generator
def Generator_MLP(latents, out_dim, reuse=False, training=True):
    units = 128

    with tf.variable_scope("generator", reuse=reuse):
        # dense0
        dense0 = tf.layers.dense(latents, units, activation=tf.nn.leaky_relu, name='dense0')
        # dropout (note: tf.layers.dropout is a pass-through unless training=True)
        dropout = tf.layers.dropout(dense0, rate=0.2, training=training, name='dropout')
        # dense1
        logits = tf.layers.dense(dropout, out_dim, name='dense1')
        # output
        outputs = tf.tanh(logits, name='outputs')

        return logits, outputs

# define discriminator
def Discriminator_MLP(inputs, out_dim, reuse=False):
    units = 128

    with tf.variable_scope("discriminator", reuse=reuse):
        # dense0
        dense0 = tf.layers.dense(inputs, units, activation=tf.nn.leaky_relu, name='dense0',
                                 kernel_initializer=tf.random_normal_initializer(0, 0.1))
        # dense1
        logits = tf.layers.dense(dense0, out_dim, name='dense1',
                                 kernel_initializer=tf.random_normal_initializer(0, 0.1))
        # output
        outputs = tf.sigmoid(logits, name='outputs')

        return logits, outputs

# count the total number of trainable parameters
def COUNT_VARS(vars):
    total_para = 0
    for variable in vars:
        # get each shape of vars
        shape = variable.get_shape()
        variable_para = 1
        for dim in shape:
            variable_para *= dim.value
        total_para += variable_para
    return total_para

# display parameter information
def ShowParasList(paras):
    p = open('./trainLog/Paras.txt', 'w')
    p.writelines(['vars_total: %d'%COUNT_VARS(paras),'\n'])
    for variable in paras:
        p.writelines([variable.name, str(variable.get_shape()),'\n'])
        print(variable.name, variable.get_shape())
    p.close()

# build related dirs
def GEN_DIR():
    if not os.path.isdir('ckpt'):
        print('DIR:ckpt NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('ckpt')
    if not os.path.isdir('trainLog'):
        print('DIR:trainLog NOT FOUND,BUILDING ON CURRENT PATH..')
        os.mkdir('trainLog')

#---------------------------------------------- build graph -------------------------------------------------------------
# hyper-parameters
latents_dim = 128
img_dim = 28*28
smooth = 0.1
learn_rate = 0.001

# define input
latents = tf.placeholder(shape=[None,latents_dim],dtype=tf.float32,name='latents')
input_real = tf.placeholder(shape=[None,img_dim],dtype=tf.float32,name='input_real')

# get output of G,D
_, g_outputs = Generator_MLP(latents,img_dim,reuse=False)
d_logits_real, d_outputs_real = Discriminator_MLP(input_real,1,reuse=False)
d_logits_fake, d_outputs_fake = Discriminator_MLP(g_outputs,1,reuse=True)

# define loss (with one-sided label smoothing on the real labels)
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                     labels=tf.ones_like(d_logits_real) * (1 - smooth)))
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                     labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                labels=tf.ones_like(d_logits_fake) * (1 - smooth)))
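# note: the (1 - smooth) factor implements label smoothing: the discriminator's
# real-label targets become 0.9 instead of 1.0, and the generator trains against
# the same smoothed labels (the usual non-saturating GAN loss).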
# gradient descent
train_vars = tf.trainable_variables()
GEN_DIR()  # ensure ./trainLog and ./ckpt exist before ShowParasList writes Paras.txt
ShowParasList(train_vars)  # display and log parameter shapes
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
d_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learn_rate).minimize(g_loss, var_list=g_vars)

#------------------------------------------------ iterations -----------------------------------------------------------

max_iters = 20000
batch_size = 64
critic_n = 5  # discriminator updates per generator update
GenLog = []
Losses = []
saver = tf.train.Saver(var_list=g_vars)

# recording training info
def SavingRecords():
    global Losses
    global GenLog
    # saving Losses
    with open('./trainLog/loss_variation.loss', 'wb') as l:
        losses = np.array(Losses)
        pickle.dump(losses, l)
        print('saving Losses successfully!')
    # saving generated samples
    with open('./trainLog/GenLog.log', 'wb') as g:
        GenLog = np.array(GenLog)
        pickle.dump(GenLog, g)
        print('saving GenLog successfully!')

# define training
def training():
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        time_start = time.time()  # go
        for steps in range(max_iters+1):

            # fetch a batch of real images and rescale from [0, 1] to [-1, 1] (tanh range)
            data_batch = mnist.train.next_batch(batch_size)[0]
            # ops.SHOW('real',data_batch[0].reshape([28,28,1]))
            data_batch = data_batch * 2 - 1
            data_batch = data_batch.astype(np.float32)
            z = np.random.normal(0, 1, size=[batch_size, latents_dim]).astype(np.float32)

            # train discriminator (critic_n updates per generator update)
            for n in range(critic_n):
                sess.run(d_train_opt, feed_dict={input_real: data_batch, latents: z})
            # train generator
            sess.run(g_train_opt, feed_dict={latents: z})

            # recording training losses
            train_loss_d = sess.run(d_loss, feed_dict={input_real: data_batch, latents: z})
            train_loss_g = sess.run(g_loss, feed_dict={latents: z})
            info = [steps, train_loss_d, train_loss_g]

            # recording generated samples and previewing a 3x3 grid
            gen_samples = sess.run(g_outputs, feed_dict={latents: z})
            visualization.CV2_BATCH_SHOW((np.reshape(gen_samples[0:9], [-1, 28, 28, 1]) + 1) / 2, 1, 3, 3, delay=1)
            print('iters::%d/%d..Discriminator_loss:%.3f..Generator_loss:%.3f..' % (steps, max_iters, train_loss_d, train_loss_g))

            if steps % 5 == 0:
                Losses.append(info)
                GenLog.append(gen_samples)

            if steps % 1000 == 0 and steps > 0:
                saver.save(sess, './ckpt/generator.ckpt', global_step=steps)

            if steps == max_iters:
                # cv2.destroyAllWindows()
                # spawn a thread to save the training records
                sleep(3)
                thread1 = Thread(target=SavingRecords,args=())
                thread1.start()

            yield info

#------------------------------------------------- ANIMATION ----------------------------------------------------------
# ANIMATION
"""
note: in this code , we will see the runtime-variation of G,D losses
"""
iters = []
dloss = []
gloss = []
fig = plt.figure()
ax1 = fig.add_subplot(2,1,1,xlim=(0, max_iters), ylim=(-1, 1))
ax2 = fig.add_subplot(2,1,2,xlim=(0, max_iters), ylim=(-20, 20))
ax1.set_title('discriminator_loss')
ax2.set_title('generator_loss')
line1, = ax1.plot([], [], color='red',lw=1,label='discriminator')
line2, = ax2.plot([], [],color='blue', lw=1,label='generator')
fig.tight_layout()

def init():
    line1.set_data([], [])
    line2.set_data([], [])
    return line1,line2

def update(info):
    iters.append(info[0])
    dloss.append(info[1])
    gloss.append(info[2])
    line1.set_data(iters, dloss)
    line2.set_data(iters, gloss)
    return line1, line2

ani = FuncAnimation(fig, update, frames=training, init_func=init, blit=True, interval=1, repeat=False)
plt.show()




Experimental Results

1. Loss curves
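Besides the live animation, the curves can be re-plotted offline. Here is a minimal sketch, assuming training has finished and SavingRecords above has written ./trainLog/loss_variation.loss:

# sketch: reload the pickled loss records and plot D/G losses over iterations
import pickle
import matplotlib.pyplot as plt

with open('./trainLog/loss_variation.loss', 'rb') as f:
    losses = pickle.load(f)  # array of [step, d_loss, g_loss] rows

plt.plot(losses[:, 0], losses[:, 1], 'r', lw=1, label='discriminator')
plt.plot(losses[:, 0], losses[:, 2], 'b', lw=1, label='generator')
plt.xlabel('iterations')
plt.ylabel('loss')
plt.legend()
plt.show()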

2. Generation log
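GenLog stores a batch of generated samples every 5 iterations, so replaying it shows how the generator improves over time. A minimal sketch, assuming ./trainLog/GenLog.log was saved by SavingRecords above; it shows one sample at nine evenly spaced points of training:

# sketch: replay the generation log
import pickle
import numpy as np
import matplotlib.pyplot as plt

with open('./trainLog/GenLog.log', 'rb') as f:
    gen_log = pickle.load(f)  # shape: [num_records, batch_size, 784]

indices = np.linspace(0, len(gen_log) - 1, 9, dtype=int)
for i, idx in enumerate(indices):
    img = (gen_log[idx][0].reshape(28, 28) + 1) / 2  # tanh range -> [0, 1]
    plt.subplot(3, 3, i + 1)
    plt.imshow(img, cmap='gray')
    plt.axis('off')
plt.show()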

3. Verifying the generator
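Because the Saver above stores only g_vars, the generator graph must be rebuilt before restoring the checkpoint. A minimal sketch, assuming the Generator_MLP definition from this post is in scope:

# sketch: rebuild the generator, restore the latest checkpoint from ./ckpt,
# and draw 9 digits from random latent vectors
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

tf.reset_default_graph()
latents = tf.placeholder(shape=[None, 128], dtype=tf.float32, name='latents')
_, g_outputs = Generator_MLP(latents, 28 * 28, reuse=False, training=False)  # definition from above

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))
    z = np.random.normal(0, 1, size=[9, 128]).astype(np.float32)
    samples = sess.run(g_outputs, feed_dict={latents: z})
    imgs = (np.reshape(samples, [-1, 28, 28]) + 1) / 2  # tanh range -> [0, 1]
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(imgs[i], cmap='gray')
        plt.axis('off')
    plt.show()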


Reposted from blog.csdn.net/Ephemeroptera/article/details/88829294