生成对抗网络的TensorFlow初探

版权声明:转载请联系作者,并注明出处 https://blog.csdn.net/onehao/article/details/89398590

生成对抗网络的TensorFlow初探

原创: 比昂 比昂日记 3月28日

之前介绍过生成对抗网络的初步原理,参见(生成对抗网络浅析(GAN))。

今天结合最近很火的TensorFlow,看看原理背后的实现。

01

模型

上一篇,参见(生成对抗网络浅析(GAN))定义了GAN模型的Model,

使用TFGAN我们组要定义4个重要属性

a. Generator,  在噪声的干扰下,生成Fake image;

b. Discriminator, 判定输入Training set,是Real,还是Fake;

c. 真实图片,Real Images;

d. Random noise;

Generator

  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
def generator_fn(noise, weight_decay=2.5e-5, is_training=True):    """G  生成MNIST图片的G网络.
    Args:        noise: Tensor表征的噪音。        weight_decay: L2正则化 -- light weight decay。        is_training: 如果为“True”,批量规范使用批量统计。如果是'False`,批量规范使用从人口中收集的指数移动平均线统计.
    Returns:        生成图像范围[-1, 1].    """    with framework.arg_scope(            [layers.fully_connected, layers.conv2d_transpose],            activation_fn=tf.nn.relu, normalizer_fn=layers.batch_norm,            weights_regularizer=layers.l2_regularizer(weight_decay)), \         framework.arg_scope([layers.batch_norm], is_training=is_training,                             zero_debias_moving_mean=True):        net = layers.fully_connected(noise, 1024)        net = layers.fully_connected(net, 7 * 7 * 256)        net = tf.reshape(net, [-1, 7, 7, 256])        net = layers.conv2d_transpose(net, 64, [4, 4], stride=2)        net = layers.conv2d_transpose(net, 32, [4, 4], stride=2)        # Make sure that generator output is in the same range as `inputs`        # ie [-1, 1].        net = layers.conv2d(net, 1, 4, normalizer_fn=None, activation_fn=tf.tanh)
        return net

Discriminator

  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
def discriminator_fn(img, unused_conditioning, weight_decay=2.5e-5,                     is_training=True):    """D  使用MNIST数字的D网络.
    Args:        img: 真实或生成的图片,范围 [-1, 1].        unused_conditioning: TFGAN API可以帮助处理条件GAN,这需要向生成器和鉴别器提供额外的“条件”信息。由于此示例不是有条件的,因此我们不使用此参数。        weight_decay: L2 正则化 weight decay。        is_training: 同G网络。
    Returns:        记录图像真实概率。    """    with framework.arg_scope(            [layers.conv2d, layers.fully_connected],            activation_fn=leaky_relu, normalizer_fn=None,            weights_regularizer=layers.l2_regularizer(weight_decay),            biases_regularizer=layers.l2_regularizer(weight_decay)):        net = layers.conv2d(img, 64, [4, 4], stride=2)        net = layers.conv2d(net, 128, [4, 4], stride=2)        net = layers.flatten(net)        with framework.arg_scope([layers.batch_norm], is_training=is_training):            net = layers.fully_connected(net, 1024, normalizer_fn=layers.batch_norm)        return layers.linear(net, 1)

Real images, 使用mnist数据源作为real images输入。

  •  
  •  
  •  
with tf.device('/cpu:0'):    real_images, _, _ = data_provider.provide_data(        'train', batch_size, MNIST_DATA_DIR)

GANModel Tuple

  •  
  •  
  •  
  •  
  •  
gan_model = tfgan.gan_model(    generator_fn,    discriminator_fn,    real_data=real_images,    generator_inputs=tf.random_normal([batch_size, noise_dims]))

02

损失函数

损失函数(loss function)是用来估量模型的预测值f(x)与真实值Y的不一致程度。

其中,前面的均值函数表示的是经验风险函数,L代表的是损失函数,后面的Φ是正则化项(regularizer)或者叫惩罚项(penalty term),它可以是L1,也可以是L2,或者其他的正则函数。整个式子表示的意思是找到使目标函数最小时的θ值。

对于GAN, 论文中的的损失函数就是二元极大极小 -- minmax, 

使用TF中的minmax损失函数

  •  
  •  
  •  
  •  
  •  
# 使用原始论问中的minmax损失函数。vanilla_gan_loss = tfgan.gan_loss(    gan_model,    generator_loss_fn=tfgan.losses.minimax_generator_loss,    discriminator_loss_fn=tfgan.losses.minimax_discriminator_loss)

同样也可以使用Wasserstein、Improved Wasserstein, 可参见论文https://arxiv.org/pdf/1701.07875.pdf

  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
# 使用 Wasserstein loss , 参考(https://arxiv.org/abs/1701.07875) # (https://arxiv.org/abs/1704.00028).improved_wgan_loss = tfgan.gan_loss(    gan_model,    # We make the loss explicit for demonstration, even though the default is    # Wasserstein loss.    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,    gradient_penalty_weight=1.0)

参考TF的实现

  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).def wasserstein_generator_loss(    discriminator_gen_outputs,    weights=1.0,    scope=None,    loss_collection=ops.GraphKeys.LOSSES,    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,    add_summaries=False):  """Wasserstein generator loss for GANs.  See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.  Args:    discriminator_gen_outputs: Discriminator output on generated data. Expected      to be in the range of (-inf, inf).    weights: Optional `Tensor` whose rank is either 0, or the same rank as      `discriminator_gen_outputs`, and must be broadcastable to      `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or      the same as the corresponding dimension).    scope: The scope for the operations performed in computing the loss.    loss_collection: collection to which this loss will be added.    reduction: A `tf.losses.Reduction` to apply to loss.    add_summaries: Whether or not to add detailed summaries for the loss.  Returns:    A loss Tensor. The shape depends on `reduction`.  """  with ops.name_scope(scope, 'generator_wasserstein_loss', (      discriminator_gen_outputs, weights)) as scope:    discriminator_gen_outputs = _to_float(discriminator_gen_outputs)
    loss = - discriminator_gen_outputs    loss = losses.compute_weighted_loss(        loss, weights, scope, loss_collection, reduction)
    if add_summaries:      summary.scalar('generator_wass_loss', loss)
  return loss

自定义损失函数。

  •  
  •  
  •  
  •  
  •  
def silly_custom_generator_loss(gan_model, add_summaries=False):    return tf.reduce_mean(gan_model.discriminator_gen_outputs)def silly_custom_discriminator_loss(gan_model, add_summaries=False):    return (tf.reduce_mean(gan_model.discriminator_gen_outputs) -            tf.reduce_mean(gan_model.discriminator_real_outputs))

03

训练&评估

训练

GAN的训练过程中,需要交替训练Generator和Discriminator网络,让Generator和Discriminator处于不断的优化和对抗中,正如论文算法的过程

过程相对比较简单,首先定义GANTrainOps的元组,然后设置优化参数

  •  
  •  
  •  
  •  
  •  
  •  
  •  
generator_optimizer = tf.train.AdamOptimizer(0.001, beta1=0.5)discriminator_optimizer = tf.train.AdamOptimizer(0.0001, beta1=0.5)gan_train_ops = tfgan.gan_train_ops(    gan_model,    improved_wgan_loss,    generator_optimizer,    discriminator_optimizer)

评估

使用‘Inception Score’和’Frechet Inception distance‘, 来衡量生成image的分布和真实image的分布的近似情况。

  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
num_images_to_eval = 500MNIST_CLASSIFIER_FROZEN_GRAPH = './mnist/data/classify_mnist_graph_def.pb'
# 要加载变量,请使用与训练job相同的变量范围。with tf.variable_scope('Generator', reuse=True):    eval_images = gan_model.generator_fn(        tf.random_normal([num_images_to_eval, noise_dims]),        is_training=False)
# 计算 Inception score.eval_score = util.mnist_score(eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)
# 计算 Frechet Inception distance.with tf.device('/cpu:0'):    real_images, _, _ = data_provider.provide_data(        'train', num_images_to_eval, MNIST_DATA_DIR)frechet_distance = util.mnist_frechet_distance(    real_images, eval_images, MNIST_CLASSIFIER_FROZEN_GRAPH)
# 重绘eval图片generated_data_to_visualize = tfgan.eval.image_reshaper(    eval_images[:20,...], num_cols=10)

训练过程和结果

TFGAN使用源于GAN minmax博弈的交替训练思路,可以更改G和D的更新比率。

  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
  •  
train_step_fn = tfgan.get_sequential_train_steps()
global_step = tf.train.get_or_create_global_step()loss_values, mnist_scores, frechet_distances  = [], [], []
with tf.train.SingularMonitoredSession() as sess:    start_time = time.time()    for i in xrange(1601):        cur_loss, _ = train_step_fn(            sess, gan_train_ops, global_step, train_step_kwargs={})        loss_values.append((i, cur_loss))        if i % 200 == 0:            mnist_score, f_distance, digits_np = sess.run(                [eval_score, frechet_distance, generated_data_to_visualize])            mnist_scores.append((i, mnist_score))            frechet_distances.append((i, f_distance))            print('Current loss: %f' % cur_loss)            print('Current MNIST score: %f' % mnist_scores[-1][1])            print('Current Frechet distance: %f' % frechet_distances[-1][1])            visualize_training_generator(i, start_time, digits_np)

可以看到如论文中的演进曲线的变化(生成对抗网络浅析(GAN)

时间维度的评估指标变化如下

扩展阅读

生成对抗网络浅析(GAN)

参考:

https://github.com/tensorflow/models/tree/master/research/gan

https://arxiv.org/pdf/1701.07875.pdf

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/gan/python/losses/python/losses_impl.py

https://blog.csdn.net/stalbo/article/details/79356739

https://zhuanlan.zhihu.com/p/44407513

http://www.csuldw.com/2016/03/26/2016-03-26-loss-function/

https://arxiv.org/pdf/1606.03498.pdf

THE END

- 晚安 -

图片长按2秒,识别图中二维码,关注订阅号

微信扫一扫
关注该公众号

猜你喜欢

转载自blog.csdn.net/onehao/article/details/89398590