对抗网络的简单版--手写数字MNIST的训练

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/yangdashi888/article/details/82711853

1、对抗网络是一种数据驱动的网络。人为干预比较少。

其中生成网络的损失利用了鉴别器的损失。而鉴别器的数据数据输入利用了生成网络的生成数据跟真实数据。

两个网络再权重更新是互不干扰。都只更新自身的权重值。

下面是简单gan网络的实现代码:

import tensorflow as tf
import numpy as np
import tensorflow.examples.tutorials.mnist as input_data
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os.path as op
import os

mnist = input_data.input_data.read_data_sets("MNIST_data/", one_hot=True)

X_dim = 784
Z_dim = 100
batch_size = 128

# 真实的输入图像的占位符
X = tf.placeholder(tf.float32, shape=[None, X_dim])
# 输入的用于生成器的输入数据的占位符
Z = tf.placeholder(tf.float32, shape=[None, Z_dim])
# 干扰样本
X_per = tf.placeholder(tf.float32, shape=[None, X_dim])
# 鉴别器使用的权重
D_W1 = tf.Variable(tf.truncated_normal([X_dim, 100], stddev=0.1), dtype=tf.float32)
D_b1 = tf.Variable(tf.constant(0.1, shape=[100]))
D_W2 = tf.Variable(tf.truncated_normal([100, 1], stddev=0.1), dtype=tf.float32)
D_b2 = tf.Variable(tf.constant(0.1, shape=[1]))
D_var_list = [D_W1, D_b1, D_W2, D_b2]
# 生成器使用的权重
G_W1 = tf.Variable(tf.truncated_normal([Z_dim, 100], stddev=0.1), dtype=tf.float32)
G_b1 = tf.Variable(tf.constant(0.1, shape=[100]))
G_W2 = tf.Variable(tf.truncated_normal([100, X_dim], stddev=0.1), dtype=tf.float32)
G_b2 = tf.Variable(tf.constant(0.1, shape=[X_dim]))
G_var_list = [G_W1, G_b1, G_W2, G_b2]


# 定义一个画生成器的生成图像的函数
def plot(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.5, hspace=0.5)
    for i, sample in enumerate(samples):
        plt.subplot(gs[i])
        plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig


# 定义鉴别其的网络
def discriminator(x):
    output1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    output2 = tf.matmul(output1, D_W2) + D_b2
    return output2


# 定义生成器网络
def genemator(x):
    output1 = tf.nn.relu(tf.matmul(x, G_W1) + G_b1)
    # 这个生成其需要使用到sigmoid
    output2 = tf.matmul(output1, G_W2) + G_b2
    output3 = tf.nn.sigmoid(output2)
    return output3


def get_perterbed_batch(minibatch):
    return minibatch + 0.5 * np.random.random(minibatch.shape)


# 定义损失函数
G_samples = genemator(Z)
D_Gsample_out = discriminator(G_samples)
D_X_out = discriminator(X)
D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=D_X_out, labels=tf.ones_like(D_X_out)))
D_loss_fake = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gsample_out, labels=tf.zeros_like(D_Gsample_out)))
# 为训练鉴别器增加一些干扰样本的损失
alpha = tf.random_uniform(
    shape=[batch_size,1],
    minval=0.,
    maxval=1.
)
differences = X_per - X
interpolates=X+(alpha*differences)
gradients = tf.gradients(discriminator(interpolates), [interpolates])[0]
slopes=tf.sqrt(tf.reduce_sum(tf.square(gradients),reduction_indices=[1]))
gradient_penalty = tf.reduce_mean((slopes-1.)**2)


sum_D_loss = D_loss_real + D_loss_fake
G_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=D_Gsample_out, labels=tf.ones_like(D_Gsample_out)))
# 定义优化器
D_train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9).minimize(sum_D_loss,
                                                                                         var_list=D_var_list)
G_train_op = tf.train.AdamOptimizer(learning_rate=0.0001, beta1=0.5, beta2=0.9).minimize(G_loss, var_list=G_var_list)
# 创建一个会话并初始化所有定义的变量
session = tf.Session()
session.run(tf.global_variables_initializer())

if not op.exists("outown/"):
    os.makedirs("outown/")

# 定义存储训练数据的变量
plotD = []
plotG = []
i = 0
# 进行训练迭代
for it in range(1, 200000):
    X_mb, _ = mnist.train.next_batch(batch_size)
    X_permb = get_perterbed_batch(X_mb)
    Z_mb = np.random.uniform(-1., 1., size=[batch_size, Z_dim])
    # 运行鉴别器图
    _, D_curloss = session.run([D_train_op, sum_D_loss], feed_dict={X: X_mb, Z: Z_mb, X_per: X_permb})
    # 运行生成器图
    _, G_curloss = session.run([G_train_op, G_loss], feed_dict={Z: Z_mb})
    plotG.append(G_curloss)
    plotD.append(D_curloss)
    if it % 1000 == 0:
        plt.subplot()
        plotnD = np.array(plotD)
        plt.plot(plotnD)
        plotnG = np.array(plotG)
        plt.plot(plotnG)
        plt.show()
        showG = np.random.uniform(-1., 1., size=[16, Z_dim])
        samples = session.run(G_samples, feed_dict={Z: showG})
        curfig = plot(samples)
        curfig.savefig("outown/{}.png".format(str(i).zfill(4)))
        print("iterate:{} ,D_loss{:.4},G_loss{:.4}".format(i, G_curloss, D_curloss))
        i += 1

其网络的损失曲线图如下:

其最后产生的图片如下:

看起来数字还是有模有样的。

猜你喜欢

转载自blog.csdn.net/yangdashi888/article/details/82711853