# GAN: generate MNIST handwritten digits with a simple fully-connected GAN

import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt
import scipy.misc
import os
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST; images come back flattened to 784 floats in [0, 1].
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
# image = image.reshape(-1,28,28)
image = mnist.train.images
# Optional: dump every training image to disk as a JPEG.
# for num in range(len(image)):
#     scipy.misc.imsave('samples'+os.sep+str(num)+'.jpg',image[num])  # scipy.misc.imsave writes an array straight to an image file
def get_inputs(real_size, noise_size):
    """Create the two input placeholders for the GAN.

    real_size: flattened size of a real image (784 for 28x28 MNIST).
    noise_size: dimensionality of the generator's noise vector.

    Returns a (real_img, noise_img) pair of float32 placeholders,
    each with a dynamic batch dimension.
    """
    return (
        tf.placeholder(tf.float32, [None, real_size], name='real_img'),
        tf.placeholder(tf.float32, [None, noise_size], name='noise_img'),
    )
def get_generator(noise_img, out_dim, reuse=False, alpha=0.01):
    """Build the generator network.

    noise_img: input noise tensor, shape (batch, noise_size).
    out_dim: size of the generated flattened image; 28*28 = 784 for MNIST.
    reuse: reuse the "generator" variable scope (e.g. when sampling after
        the training graph has already created the variables).
    alpha: leaky-ReLU negative slope.

    Returns (logits, outputs); outputs = tanh(logits), so pixel values
    lie in (-1, 1).

    NOTE(review): batch_normalization is built with training=True and its
    UPDATE_OPS are never run explicitly; that is harmless here because the
    moving averages are never consumed, but confirm before adding an
    inference mode.
    """
    with tf.variable_scope("generator", reuse=reuse):

        hidden1 = tf.layers.dense(noise_img, 256)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # leaky ReLU
        hidden1 = tf.layers.batch_normalization(hidden1, momentum=0.8, training=True)

        hidden2 = tf.layers.dense(hidden1, 512)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)
        hidden2 = tf.layers.batch_normalization(hidden2, momentum=0.8, training=True)

        # BUG FIX: the original fed hidden1 (not hidden2) into this dense
        # layer and then overwrote hidden3 with the leaky-ReLU of hidden2,
        # leaving the 1024-unit layer completely dead. Chain the layers
        # properly: hidden2 -> dense(1024) -> leaky ReLU -> batch norm.
        hidden3 = tf.layers.dense(hidden2, 1024)
        hidden3 = tf.maximum(alpha * hidden3, hidden3)
        hidden3 = tf.layers.batch_normalization(hidden3, momentum=0.8, training=True)

        # logits & outputs
        logits = tf.layers.dense(hidden3, out_dim)
        outputs = tf.tanh(logits)

        return logits, outputs
def get_discriminator(img,  reuse=False, alpha=0.01):
    """Build the discriminator network.

    img: flattened image tensor (real or generated).
    reuse: share weights between the real-image and fake-image branches.
    alpha: leaky-ReLU negative slope.

    Returns (logits, outputs) where outputs = sigmoid(logits) is the
    probability that img is a real image.
    """
    with tf.variable_scope("discriminator", reuse=reuse):
        def leaky_relu(x):
            # max(alpha*x, x) implements leaky ReLU for 0 < alpha < 1
            return tf.maximum(alpha * x, x)

        net = leaky_relu(tf.layers.dense(img, 512))
        net = leaky_relu(tf.layers.dense(net, 256))

        # logits & outputs
        logits = tf.layers.dense(net, 1)
        outputs = tf.sigmoid(logits)

        return logits, outputs
# ----- Hyper-parameters and graph construction -----
# Flattened size of a real image (784 for 28x28 MNIST)
img_size = mnist.train.images[0].shape[0]
# Dimensionality of the noise vector fed to the generator
noise_size = 100
# Leaky-ReLU negative slope
alpha = 0.2 #0.01
# learning_rate
learning_rate = 0.0002
# One-sided label smoothing factor
smooth = 0.1
tf.reset_default_graph()
 
real_img, noise_img = get_inputs(img_size, noise_size)
 
# generator
g_logits, g_outputs = get_generator(noise_img, img_size)
 
# discriminator: one branch for real images, one weight-shared branch for fakes
d_logits_real, d_outputs_real = get_discriminator(real_img)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, reuse=True)
# ----- Discriminator loss -----
# Real images: target 1, with one-sided label smoothing (targets become
# 1 - smooth). BUG FIX: the original multiplied the *loss* by (1 - smooth)
# instead of smoothing the labels, which merely rescales the gradient and
# provides no smoothing at all.
d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,
                                                                     labels=tf.ones_like(d_logits_real) * (1 - smooth)))
# Generated images: target 0
d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                     labels=tf.zeros_like(d_logits_fake)))
# Total discriminator loss
d_loss = tf.add(d_loss_real, d_loss_fake)
 
# Generator loss: make the discriminator output 1 on fakes
# (same label-smoothing fix as d_loss_real, matching the original intent)
g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,
                                                                labels=tf.ones_like(d_logits_fake) * (1 - smooth)))
train_vars = tf.trainable_variables()
 
# Variables owned by the generator (scope "generator")
g_vars = [var for var in train_vars if var.name.startswith("generator")]
# Variables owned by the discriminator (scope "discriminator")
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
 
# Optimizers: each network updates only its own variables.
# NOTE(review): the generator uses batch_normalization with training=True
# but g_train_opt has no control dependency on tf.GraphKeys.UPDATE_OPS;
# harmless here because the moving averages are never consumed — confirm
# before adding an inference (training=False) path.
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
# batch_size
batch_size = 64
# Number of training epochs
epochs = 2
# Per-epoch generator samples kept for later inspection
samples = []
# Per-epoch loss records
losses = []
# Checkpoint only the generator's variables
saver = tf.train.Saver(var_list = g_vars)
# saver = tf.train.Saver()
# saver = tf.train.Saver()
# 开始训练
# ----- Training -----
# Build the sampling subgraph ONCE, reusing the generator weights.
# BUG FIX: the original called get_generator(...) inside the epoch loop,
# adding a fresh subgraph to the TF graph on every epoch (unbounded graph
# growth and progressively slower sess.run calls).
sample_logits, sample_outputs = get_generator(noise_img, img_size, reuse=True)

# BUG FIX: tf.train.Saver.save raises if the parent directory is missing;
# the original never created ./checkpoints.
if not os.path.exists('./checkpoints'):
    os.makedirs('./checkpoints')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for batch_i in range(mnist.train.num_examples//batch_size):
            # Sample a random mini-batch of real images.
            X_train = mnist.train.images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            batch_images = X_train[idx]
            # Rescale pixels from [0, 1] to (-1, 1) to match the generator's
            # tanh output (real and fake images share the discriminator).
            batch_images = batch_images*2 - 1

            # Noise input for the generator
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))

            # One discriminator update, then one generator update
            _ = sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})
            _ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise})

        # End-of-epoch loss report (evaluated on the last mini-batch only).
        # Fetch all four losses in one sess.run instead of four separate runs.
        feed = {real_img: batch_images, noise_img: batch_noise}
        train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g = sess.run(
            [d_loss, d_loss_real, d_loss_fake, g_loss], feed_dict=feed)

        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(train_loss_d, train_loss_d_real, train_loss_d_fake),
              "Generator Loss: {:.4f}".format(train_loss_g))
        # Record the losses for the post-training plot
        losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))

        # Draw 100 samples from the generator for visual inspection
        sample_noise = np.random.uniform(-1, 1, size=(100, noise_size))
        gen_samples = sess.run([sample_logits, sample_outputs],
                               feed_dict={noise_img: sample_noise})
        # Save a 10x10 grid of this epoch's samples to images/<epoch>.png
        gen_imgs = gen_samples[1].reshape(100, 28, 28, 1)
        gen_imgs = 0.5 * gen_imgs + 0.5  # map tanh range (-1,1) back to [0,1]
        r, c = 10, 10
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        if not os.path.exists('images'):
            os.mkdir('images')
        fig.savefig("images/%d.png" % e)
        plt.close()
        samples.append(gen_samples)
        # Checkpoint the generator variables each epoch
        saver.save(sess, './checkpoints/generator.ckpt')
# Plot the recorded per-epoch losses.
fig, ax = plt.subplots(figsize=(20,7))
losses = np.array(losses)
loss_labels = ['Discriminator Total Loss',
               'Discriminator Real Loss',
               'Discriminator Fake Loss',
               'Generator']
for series, label in zip(losses.T, loss_labels):
    plt.plot(series, label=label)
plt.title("Training Losses")
plt.legend()

# Persist the per-epoch sample batches for later inspection.
with open('train_samples.pkl','wb') as f:
    pickle.dump(samples,f)

#####################################################################
#生成最终的生成图片
def gen():
    """Restore the trained generator and write 64 sample digits to samples/."""
    # Restore only the generator's variables from the latest checkpoint
    saver = tf.train.Saver(var_list=g_vars)
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
        # The batch size here is arbitrary: noise_img was declared with a
        # dynamic batch dimension (the original comment claiming it must
        # match the training batch size was wrong).
        sample_noise = np.random.uniform(-1, 1, size=(64, noise_size))
        gen_samples = sess.run(get_generator(noise_img,  img_size, reuse=True),
                               feed_dict={noise_img: sample_noise})
        # gen_samples[1] is the tanh output, shape (64, 784)
        a = gen_samples[1].reshape(64, 28, 28)
        print(a.shape)
        # BUG FIX: the directory-existence check was inside the per-image
        # loop; hoist it so it runs once.
        if not os.path.exists('samples'):
            os.mkdir('samples')
        for num in range(len(a)):
            scipy.misc.imsave('samples'+os.sep+str(num)+'.jpg', a[num])
gen()
###################################################################
# 生成每轮的生成图片
def gen1():
    """Write the generator samples recorded at one training epoch as JPEGs.

    Reads the pickled per-epoch samples dumped during training; no TF
    session is needed. (BUG FIX: the original opened a session and
    restored the generator checkpoint but never used it — the images come
    entirely from the pickle.)
    """
    with open('train_samples.pkl','rb') as f:
        samples = pickle.load(f)
    # samples[epoch][1] -> generator outputs for that epoch, shape (n, 784)
    b = samples[1][1]
    b = b.reshape(-1, 28, 28)
    # BUG FIX: the original wrote into a misspelled 'samplse' directory and
    # never created it, so imsave failed on a fresh run. Use 'samples', the
    # same directory gen() uses, and create it if missing.
    if not os.path.exists('samples'):
        os.mkdir('samples')
    for num in range(len(b)):
        scipy.misc.imsave('samples'+os.sep+str(num)+'.jpg', b[num])
# gen1()
#####################################################################
# Load samples from generator taken while training
# Reload the per-epoch generator samples recorded during training.
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)
def view_samples(epoch, samples):
    """Show a 5x5 grid of generator outputs from one recorded epoch.

    epoch: index into samples (which epoch's outputs to display).
    samples: list of (logits, outputs) pairs, one entry per epoch.

    Returns the matplotlib (fig, axes) pair.
    """
    fig, axes = plt.subplots(figsize=(7,7), nrows=5, ncols=5, sharey=True, sharex=True)
    # samples[epoch][1] holds the generated images; [0] would be the logits
    images = samples[epoch][1]
    for ax, img in zip(axes.flatten(), images):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28,28)), cmap='Greys_r')

    return fig, axes
# _ = view_samples(-1, samples) # 显示最后一轮的outputs
########################################################################
# 指定要查看的轮次
# epoch_idx = [0, 5, 10, 20, 40, 60, 80, 100, 150, 250] # 一共300轮,不要越界
# show_imgs = []
# for i in epoch_idx:
#     show_imgs.append(samples[i][1])
# # 指定图片形状
# rows, cols = 10, 25
# fig, axes = plt.subplots(figsize=(30,12), nrows=rows, ncols=cols, sharex=True, sharey=True)
# 
# idx = range(0, epochs, int(epochs/rows))
# 
# for sample, ax_row in zip(show_imgs, axes):
#     for img, ax in zip(sample[::int(len(sample)/cols)], ax_row):
#         ax.imshow(img.reshape((28,28)), cmap='Greys_r')
#         ax.xaxis.set_visible(False)
#         ax.yaxis.set_visible(False)
#######################################################################
# 加载我们的生成器变量
# saver = tf.train.Saver(var_list=g_vars)
# with tf.Session() as sess:
#     saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
#     sample_noise = np.random.uniform(-1, 1, size=(25, noise_size))
#     gen_samples = sess.run(get_generator(noise_img, img_size, reuse=True),
#                            feed_dict={noise_img: sample_noise})
# _ = view_samples(0, [gen_samples])
#####################################################################

    

# Adapted from: blog.csdn.net/qq_38826019/article/details/80616202