李宏毅MLDS课程笔记9:Generative Adversarial Network(GAN)

NIPS 2016 Tutorial: Generative Adversarial Networks [Paper] [Video]
tips for training GAN: https://github.com/soumith/ganhacks

Basic Idea of GAN

1、最大化似然函数等价于最小化 Pdata PG 之间的KL散度。如果 PG 取的是GMM的话,产生的image很模糊(GMM与 Pdata 差太多,无法模拟 Pdata )。GAN的好处之一是,用NN来定义 PG PG 更一般化,可以取得更好的效果。
这里写图片描述

2、现在 PG 是一个NN,参数为 θ ,从input distribution 中sample得到input low-dim vector z ,经过NN之后,得到generated distribution PG . 根据 θ 的不同,可以用简单的input distribution 产生各种复杂的distribution。

这样做的问题是难以计算likelihood。
GAN解决了这个问题,在无法计算likelihood的情况下更新 θ ,使 PG 更像 Pdata
解决方法是从天而降一个Discriminator D, 解一个最小最大问题就得到了Generator function(NN) G ,使得input distribution经过Generator function(NN) G 得到的 PG Pdata 最接近。

这个最小最大问题是

G=argminGmaxDV(G,D)

其中
V(G,D)=ExPdata[logD(x)]+ExPG[log(1D(x))]

这样定义的好处是,
maxDV(G,D)=V(G,D)=2log2+2JSD(Pdata(x)||PG(x))
衡量了 PG Pdata 之间的Jensen-Shannon散度(与KL散度不同,JS散度是对称的)。

现在我们已经把 maxDV(G,D) 搞定了,剩下的问题是如何求解

G=argminGmaxDV(G,D)=argminGJSD(Pdata(x)||PG(x))

方法(梯度下降):
1、初始化 G0
2、得到 D0 , V(G0,D0)=JSD(Pdata(x)||PG0(x))
3、用 V(G,D0) θG 的梯度来更新 θG ,得到 G1 . 有

V(G1,D0)<V(G0,D0)

然而却不一定有
JSD(Pdata(x)||PG1(x))=V(G1,D1)<V(G0,D0)=JSD(Pdata(x)||PG0(x))

4、得到 D1 , V(G1,D1)=JSD(Pdata(x)||PG1(x))
5、用 V(G,D1) θG 的梯度来更新 θG ,得到 G2 .有
V(G2,D1)<V(G1,D1)

然而却不一定有
JSD(Pdata(x)||PG2(x))=V(G2,D2)<V(G1,D1)=JSD(Pdata(x)||PG1(x))

……

为了尽量避免出现上面的“不一定有”的情况,对G每次不能更新太多。

扫描二维码关注公众号,回复: 1677127 查看本文章

3、实际操作中
无法用积分计算V(G,D)中的期望,通过采样的方法得到 V~ 来近似 V 。而 V~ 的形式与Binary Classifier的loss function形式相同。这是符合直觉的,因为我们要得到的Discriminator D就是一个Binary Classifier。
这里写图片描述
(图中L整体缺少一个负号)

所以实际中的算法是:
这里写图片描述
其中, Pprior(z) 是自己定的简单分布。
之所以要把Learning D的过程重复多次是因为每次得到的不是 maxDV(G,D) ,而是 maxDV(G,D) 的一个lower bound,重复多次可以使lower bound变大。
之所以把Learning G的过程只进行一次,原因就是上面说过的每次更新G的时候不能更新太多,以免JSD不降反增。

另外,实际中在Learning G的时候,目标函数也有所改变:
这里写图片描述
这是因为在开始的时候D(x)很小,此时目标函数的微分小,所以训练慢。
改了目标函数之后,在D(x)很小的时候训练速度变快,在D(x)接近1(我们的目标)的时候训练速度慢下来。

Issue about Evaluating the Divergence

实际中遇到一个问题,从discriminator的loss中无法看出生成图片的质量是否变好,因为loss总是基本为0,即discriminator认为 Pdata PG 完全没有overlap。
这有两个原因。
一是discriminator过于强大,将 Pdata PG 的采样点用复杂的边界区分开。
若是要减弱discriminator的话(update次数少一点、dropout、用比较少的参数),不知道discriminator要调到什么地步才能得到好的结果。而且,discriminator可以量JSD的前提是,discriminator可以是任何function,因此又希望discriminator能powerful一些。
二是 Pdata PG 本身就没有很多overlap。
解决方法是加噪声:
这里写图片描述

Mode Collapse

Conditional GAN

待续

猜你喜欢

转载自blog.csdn.net/xzy_thu/article/details/70505141