原始GAN论文笔记及TensorFlow实现

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/littlehaes/article/details/81265872

Welcome To My Blog

引言

  • 在GAN诞生之前,比起生成模型而言,判别模型更受关注,比如Alex Net,VGG,Google Net,因为典型的生成模型往往具有原理复杂,推导复杂,实现复杂的特点
  • 对于生成模型而言,通常有两种建模方式
    • 最常见的是对目标对象的概率分布建模,将其表达成具体的某种参数形式,再通过最大似然一类的方法训练模型,如深度玻尔兹曼机DBM,这样做的缺点:通常得到的似然函数无法直接求解,需要借助近似算法或者采样算法
    • 采用非参数的方式建模,如GSN,方法核心:假设一条马尔科夫链的稳态分布是数据的真实分布,然后将马尔科夫链的求解操作替换为可以用梯度反向传播来执行的操作
  • GAN作为一种训练框架,由两个网络Generator和Discriminator构成,D采用判别式准则辅助训练生成模型G,结构如下,X是真实数据,Z是随机噪声,Z经过Generator后成为X’;X’和X作为Discriminator的输入,Discriminator根据X判断X’是不是真实的数据,并将结果反馈给Generator.GAN目的就是希望X’尽可能地接近X,也就是P_g = P_data
    2.png

GAN的两个网络

Generator

G本来就是做生成的,比如Auto Encoder就是一种生成模型,GAN为什么要增加D呢?因为只用G有缺陷,以AE为例,AE侧重于生成与原图片尽可能相似的图片,但这样会牺牲掉图片中各个component之间的联系,如下图所示
1.png
对于AE来说,output1更像原图,但是我们写数字时,笔画的中间往往不会有空缺,也就是说,虽然output2最后的笔画拉长了,但比起output1来说更自然,因为output2更care各个component之间的联系.
当然了,AE可以通过增加神经网络的层数使得网络可以考虑这种联系,但是生成相同质量图片的情况下,GAN的结构更加简单.

Discriminator

D虽然是判别模型,但也可以做生成,需要解下面这个式子
3.png
也就是对于给定的输入x,遍历所有可能的数据,挑出分数最高的图片作为生成结果.但是首先需要假设D(x)的形式,如果假设D(x)是线性的,那么模型的能力太弱;如果假设是非线性的,又不好解argmax.
在GAN中,D的输入是G的输出,G的输出是一张完整的图片,D对一张完整的图片进行判别可以很好地catch到各个component之间的联系,然后将这个信息反馈给G,从而使G生成具有大局观的图片

GAN的数学推导

IanGoodfellow的论文Generative Adversarial Nets是这样引出GAN的目标函数的
对于Discriminator来说,它用来判断输入的数据是真还是假,具体做法是:对真实的数据赋予高分,对虚假的数据赋予低分;也就是希望赋予D(X)高分,赋予D(X’)低分,可以写成如下的形式
4.png
+ 取1-D(X’)是为了满足对数的定义域要求
+ 取对数,个人认为是为了凑alog(x)+blog(1-x)的形式,之后会提到
+ 取期望是把分布P_data和P_g考虑进来
对于Generator来说,希望自己生成的数据X’更接近真实数据X,也就是希望D(X’)越大越好,这便体现了G与D的博弈思想,结合G与D的初衷可得目标函数为:
5.png

目标函数的有效性

优化V(D,G)后,等价于实现了P_g = P_data,下面说明原因:

固定G,优化D

首先直接使用一个概率论中的定理:
6.png
将V(D,G)展开
7.png
最后一步合并了两个积分,从而扩大了积分限,两个被积函数在无定义处取0即可
刚才提到为什么目标函数采用对数形式,原因如下
8.png
目标函数正好符合上述定理形式,所以固定G,优化D时D的最优值为:
9.png

固定D,优化G

10.png
当P_g = P_data时,上面的不等式取等号,C(G)取得最小值,说明按照上面的方式优化目标函数,效果相当于P_g = P_data,说明了GAN的可行性

优化流程

11.png
优化D的时候优化了k次,不过论文中实验的时候取k=1
在优化G的初期,由于G生成的数据X’很假,所以log(1-D(G(z)))的梯度接近1,有点小,不利于迭代,所以会使用max_G log(D(G(z)))优化G

TensorFlow实现

完整代码可以参考深度学习-GAN专题代码复现中的”GAN的诞生”.
如果对logistic regression和交叉熵有一定的认识会对理解代码实现有很大帮助
1. 关于交叉熵,可以参考交叉熵与KL散度
2. 关于logistic regression,可以参考Logistic Regression逻辑斯蒂回归
3. TF文档中关于logistic loss的解释
12.png

# 输入噪声从正态分布中采样得到
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
    return tf.random_normal(shape=size, stddev=xavier_stddev)
# Generator
def generator(z):
    G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
    G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
    G_prob = tf.nn.sigmoid(G_log_prob)

    return G_prob
# Discriminator
def discriminator(x):
    D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
    D_logit = tf.matmul(D_h1, D_W2) + D_b2
    D_prob = tf.nn.sigmoid(D_logit)

    return D_prob, D_logit

个人总结

  1. GAN是一种框架,核心思想是对抗训练:针对D,希望赋予D(X)高分,赋予D(X’)低分;针对G,希望赋予D(G(z))高分.
  2. 这种对抗训练思想的有效性是通过求解下面的目标函数实现的,求解结果是P_G=P_data
    5.png
  3. 代码实现时,只要能够体现GAN的核心思想即可,使用TensorFlow实现原始GAN模型时,由于TF有simoid_cross_entropy_with_logits这个函数,所以可以使用logistic regression对X和X’进行二分类.此时最大化样本构成的似然函数,相当于最小化样本标签和D输出之间的交叉熵.

最后推荐一下杨双老师的课程,深度学习-GAN专题论文研读,老师讲得非常棒

参考:
杨双:深度学习-GAN专题论文研读
李宏毅对抗生成网络
统计学习方法

猜你喜欢

转载自blog.csdn.net/littlehaes/article/details/81265872