"""Simple fully-connected GAN trained on MNIST (TensorFlow 1.x).

Trains a 3-hidden-layer generator against a 2-hidden-layer discriminator
with one-sided label smoothing, saves per-epoch sample grids and a
generator-only checkpoint, then regenerates digits from the checkpoint.
"""
import tensorflow as tf
import numpy as np
import pickle
import matplotlib.pyplot as plt
import scipy.misc
import os
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
image = mnist.train.images
# for num in range(len(image)):
#     scipy.misc.imsave('samples' + os.sep + str(num) + '.jpg', image[num])


def get_inputs(real_size, noise_size):
    """Placeholders for the real-image batch and the generator noise batch."""
    real_img = tf.placeholder(tf.float32, [None, real_size], name='real_img')
    noise_img = tf.placeholder(tf.float32, [None, noise_size], name='noise_img')
    return real_img, noise_img


def get_generator(noise_img, out_dim, reuse=False, alpha=0.01):
    """Generator network.

    noise_img: input noise tensor
    out_dim:   size of the generated (flattened) image, 28*28 = 784 here
    alpha:     leaky-ReLU slope
    Returns (logits, outputs) where outputs = tanh(logits) in (-1, 1).
    """
    with tf.variable_scope("generator", reuse=reuse):
        hidden1 = tf.layers.dense(noise_img, 256)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden1 = tf.layers.batch_normalization(hidden1, momentum=0.8, training=True)

        hidden2 = tf.layers.dense(hidden1, 512)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)
        hidden2 = tf.layers.batch_normalization(hidden2, momentum=0.8, training=True)

        # BUGFIX: the original fed hidden1 into this dense layer and applied
        # the leaky ReLU to hidden2 again, so the 1024-unit layer's output
        # was discarded and it never contributed to the image.
        hidden3 = tf.layers.dense(hidden2, 1024)
        hidden3 = tf.maximum(alpha * hidden3, hidden3)
        hidden3 = tf.layers.batch_normalization(hidden3, momentum=0.8, training=True)

        # logits & outputs
        logits = tf.layers.dense(hidden3, out_dim)
        outputs = tf.tanh(logits)
        return logits, outputs


def get_discriminator(img, reuse=False, alpha=0.01):
    """Discriminator network; returns (logits, sigmoid probabilities)."""
    with tf.variable_scope("discriminator", reuse=reuse):
        hidden1 = tf.layers.dense(img, 512)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)
        hidden2 = tf.layers.dense(hidden1, 256)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)
        # logits & outputs
        logits = tf.layers.dense(hidden2, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs


# ---- hyper-parameters ------------------------------------------------------
img_size = mnist.train.images[0].shape[0]   # flattened 28*28 = 784
noise_size = 100
# NOTE(review): this global alpha is never passed to the networks, which
# therefore use their default leaky-ReLU slope of 0.01 — confirm intent.
alpha = 0.2  # 0.01
learning_rate = 0.0002
smooth = 0.1  # one-sided label smoothing on the "real" targets

tf.reset_default_graph()

real_img, noise_img = get_inputs(img_size, noise_size)

# generator
g_logits, g_outputs = get_generator(noise_img, img_size)
# discriminator (shared weights for real and fake batches)
d_logits_real, d_outputs_real = get_discriminator(real_img)
d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, reuse=True)

# Discriminator loss: score real images as 1 (smoothed), fakes as 0.
d_loss_real = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_real, labels=tf.ones_like(d_logits_real)) * (1 - smooth))
d_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.zeros_like(d_logits_fake)))
d_loss = tf.add(d_loss_real, d_loss_fake)

# Generator loss: fool the discriminator into scoring fakes as 1.
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_logits_fake, labels=tf.ones_like(d_logits_fake)) * (1 - smooth))

train_vars = tf.trainable_variables()
g_vars = [var for var in train_vars if var.name.startswith("generator")]
d_vars = [var for var in train_vars if var.name.startswith("discriminator")]

# BUGFIX: tf.layers.batch_normalization registers its moving-mean/variance
# updates in GraphKeys.UPDATE_OPS; without this control dependency those
# statistics are never updated during training.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
    g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)

batch_size = 64
epochs = 2
samples = []   # per-epoch generated samples, for later inspection
losses = []    # per-epoch (d_total, d_real, d_fake, g) losses

# Persist only the generator variables.
saver = tf.train.Saver(var_list=g_vars)
# saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for batch_i in range(mnist.train.num_examples // batch_size):
            # batch = mnist.train.next_batch(batch_size)
            # batch_images = batch[0].reshape((batch_size, 784))
            X_train = mnist.train.images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            batch_images = X_train[idx]
            # Scale pixels from [0, 1] to (-1, 1) to match the generator's
            # tanh output, since real and fake images share the discriminator.
            batch_images = batch_images * 2 - 1
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
            # Run optimizers
            _ = sess.run(d_train_opt, feed_dict={real_img: batch_images, noise_img: batch_noise})
            _ = sess.run(g_train_opt, feed_dict={noise_img: batch_noise})

        # End-of-epoch loss monitoring (on the last batch of the epoch)
        train_loss_d = sess.run(d_loss, feed_dict={real_img: batch_images, noise_img: batch_noise})
        train_loss_d_real = sess.run(d_loss_real, feed_dict={real_img: batch_images, noise_img: batch_noise})
        train_loss_d_fake = sess.run(d_loss_fake, feed_dict={real_img: batch_images, noise_img: batch_noise})
        train_loss_g = sess.run(g_loss, feed_dict={noise_img: batch_noise})
        print("Epoch {}/{}...".format(e + 1, epochs),
              "Discriminator Loss: {:.4f}(Real: {:.4f} + Fake: {:.4f})...".format(
                  train_loss_d, train_loss_d_real, train_loss_d_fake),
              "Generator Loss: {:.4f}".format(train_loss_g))
        losses.append((train_loss_d, train_loss_d_real, train_loss_d_fake, train_loss_g))

        # Draw 100 samples for later inspection
        sample_noise = np.random.uniform(-1, 1, size=(100, noise_size))
        gen_samples = sess.run(get_generator(noise_img, img_size, reuse=True),
                               feed_dict={noise_img: sample_noise})

        # Plot this epoch's samples on a 10x10 grid
        gen_imgs = gen_samples[1].reshape(100, 28, 28, 1)
        gen_imgs = 0.5 * gen_imgs + 0.5   # rescale (-1, 1) back to [0, 1] for display
        r, c = 10, 10
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        if not os.path.exists('images'):
            os.mkdir('images')
        fig.savefig("images/%d.png" % e)
        plt.close()

        samples.append(gen_samples)
        # BUGFIX: Saver.save does not create the target directory itself.
        if not os.path.exists('checkpoints'):
            os.mkdir('checkpoints')
        saver.save(sess, './checkpoints/generator.ckpt')

fig, ax = plt.subplots(figsize=(20, 7))
losses = np.array(losses)
plt.plot(losses.T[0], label='Discriminator Total Loss')
plt.plot(losses.T[1], label='Discriminator Real Loss')
plt.plot(losses.T[2], label='Discriminator Fake Loss')
plt.plot(losses.T[3], label='Generator')
plt.title("Training Losses")
plt.legend()

with open('train_samples.pkl', 'wb') as f:
    pickle.dump(samples, f)


def gen():
    """Restore the trained generator and save 64 generated digits as JPGs."""
    saver = tf.train.Saver(var_list=g_vars)
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
        sample_noise = np.random.uniform(-1, 1, size=(64, noise_size))
        gen_samples = sess.run(get_generator(noise_img, img_size, reuse=True),
                               feed_dict={noise_img: sample_noise})
        # gen_samples[1].shape == (64, 784); index [1] selects the tanh outputs
        a = gen_samples[1].reshape(64, 28, 28)
        print(a.shape)
        if not os.path.exists('samples'):
            os.mkdir('samples')
        for num in range(len(a)):
            scipy.misc.imsave('samples' + os.sep + str(num) + '.jpg', a[num])


gen()


def gen1():
    """Dump the samples recorded during one training epoch as JPGs."""
    with open('train_samples.pkl', 'rb') as f:
        samples = pickle.load(f)
    saver = tf.train.Saver(var_list=g_vars)
    with tf.Session() as sess:
        saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
        b = samples[1][1]   # [epoch index][1] -> outputs, shape (n_sample, 784)
        b = b.reshape(-1, 28, 28)
        # BUGFIX: the original wrote into a misspelled 'samplse' directory
        # that was never created, so every imsave call failed.
        if not os.path.exists('samples'):
            os.mkdir('samples')
        for num in range(len(b)):
            scipy.misc.imsave('samples' + os.sep + str(num) + '.jpg', b[num])


# gen1()

# Load samples from generator taken while training
with open('train_samples.pkl', 'rb') as f:
    samples = pickle.load(f)


def view_samples(epoch, samples):
    """Show a 5x5 grid of samples[epoch][1] (generated images; [0] is logits)."""
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    for ax, img in zip(axes.flatten(), samples[epoch][1]):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        im = ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
    return fig, axes


# _ = view_samples(-1, samples)  # show the last epoch's outputs

# ---- optional exploration (pick epochs to view; don't exceed len(samples)) --
# epoch_idx = [0, 5, 10, 20, 40, 60, 80, 100, 150, 250]
# show_imgs = []
# for i in epoch_idx:
#     show_imgs.append(samples[i][1])
# rows, cols = 10, 25
# fig, axes = plt.subplots(figsize=(30, 12), nrows=rows, ncols=cols, sharex=True, sharey=True)
# for sample, ax_row in zip(show_imgs, axes):
#     for img, ax in zip(sample[::int(len(sample) / cols)], ax_row):
#         ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
#         ax.xaxis.set_visible(False)
#         ax.yaxis.set_visible(False)

# ---- generate fresh samples from the checkpoint -----------------------------
# saver = tf.train.Saver(var_list=g_vars)
# with tf.Session() as sess:
#     saver.restore(sess, tf.train.latest_checkpoint('checkpoints'))
#     sample_noise = np.random.uniform(-1, 1, size=(25, noise_size))
#     gen_samples = sess.run(get_generator(noise_img, img_size, reuse=True),
#                            feed_dict={noise_img: sample_noise})
# _ = view_samples(0, [gen_samples])
# GAN generating MNIST handwritten digits
# Adapted from: blog.csdn.net/qq_38826019/article/details/80616202
# NOTE: the trailing lines "猜你喜欢" (you may also like), "今日推荐" (today's
# picks) and "周排行" (weekly ranking) were page-navigation residue from the
# web scrape, not part of the article or the program.