上一篇博文中详细总结和推导了
GAN
网络的原理,但是如此的
GAN
网络有他的不足之处,本博文将详细说明其不足之处,以及解决和提高的办法。
original GAN 不足之处
简单回顾GAN网络原理
蓝色的线表示:
Generated distribution
绿色的线表示:
Data(target) distribution
红色的线表示:
Discriminator
在上图中的左上第一个子图中,
generator
生成的数据分布与
Data distribution
相差较大,则
Discriminator
也即是
D(x)
给
Generated distribution
以较低的概率,而给
Data(target) distribution
以较高的概率,由此得到
D(X)
的曲线走向。在左二子图中,更新后的
generator
可能会因为更新步伐太大,移到了
Data distribution
的右边,由此
D(X)
更新如图,
G、D
如此不断的更新迭代,最终
Generated distribution
与
Data(target) distribution
重合,那么此时
D(X)
就变成了一条水平直线。
存在的问题
我们知道整个
GAN
网络的目标都是在:
通过不断的更新
D、G
来得到比较好的
Generator
,也就是上式的
G∗
,那么在更新
D
时:
max V(G,D) = −2log2 + 2JSD(Pdata(X)||PG(X))
我们是不断的通过
Minize max V(G,D)
来更新
G
,那么问题来了,这个
max V(G,D)
是否能准确的反映
Pdata
与
PG
之间的差距呢?
由上图可以看出,当
PG
与
Pdata
无重合时(可能是sample出的样本没有重合),即使两者的
distribution
在改进,其
JS(PG||Pdata)
始终为
log2
,那么在更新
G
参数时,没有改进的动力。很难得到很好的
Generator
。
Unified Framework
f-divergence
之前我们介绍的
GAN
网络中的
Discriminator
只是和
Jensen−Shannon divergence
有关,论文
Training Generative Neural Samplers using Variational Divergence Minimization
中介绍了
f−divergence
,其
Discriminator
不只是仅仅由
Jensen−Shannon divergence
来定义,其核心的一句话就是
you can use any f−divergence
我们假设有两个分布,分别
p
和
q
,代入到
GAN
网络中,就是之前说的
Pdata
和
PG
,其中
p(x)
和
q(x)
就是
sample
出来的样本的概率。由此我们可以这样来定义
f−divergence
:
Df(P||Q)=∫xq(x) f(p(x)q(x))dx
显然这样定义的
Df(P||Q)
必须能起到衡量
P、Q
分布的拟合程度,并且值越小拟合的越好。那么就必须具备以下条件:
那么可以得到,当对于所有的
x
都有
P(x)==Q(x)
时:
Df(P||Q)=0
,这个时候显然拟合的最好,并且是
smallest Df(P||Q)
。
再由凸函数的特性可得:
Df(P||Q)=∫xq(x) f(p(x)q(x))dx≥f(∫xq(x) p(x)q(x)dx)=f∫xp(x)=1
故可得到
Df(P||Q)≥1
。
其实
KL divergence
就可以理解为一种
f−divergence
。那么
f
可以选哪些函数呢?只要符合上面的要求即可:
Fenchel Conjugate
首先假设
f(x)
是一个凸函数,定义如下公式:
f∗(t)=maxx⫅dom(f){xt−f(x)}
得到
f∗(t)
,这里固定住不同的
x例如(x1,x2,..)
,都能得到不同的关于
t
的线性函数,其图可以如下:
然后取其
max
,就得到上图红色的那条线。由此可以得到一个结论:
若f(x)为凸函数,则必存在其对应的f∗(t),且也是凸函数
我们把这样的
f∗(t)
叫做
f(x)
的
conjugate function
举个具体的例子,当
f(x)=xlogx
时,可得
f∗(t)
:
那么如何得出当
f(x)=xlogx
时,
f∗(t)
的具体数学公式呢?
那么可得结论,当
f(x)=xlogx
时,其
conjugate function f∗(t)=exp(t−1)
,也即:
f(x)=xlogx↔f∗(t)=exp(t−1)
这里需要注意:
(f∗)∗=f
Connect to GAN
那么上面讲的与
GAN
有什么关系呢?
假设
f(x)为凸函数,且f(1)=0
,则由上面的推导我们可以得出:
f∗(t)=maxx⫅dom(f){xt−f(x)}↔f(x)=maxt⫅dom(f∗){tx−f∗(t)}
在
f(x)=maxt⫅dom(f∗){tx−f∗(t)}
中,我们可以令
x=p(x)q(x)
,再由上面已经得出的
f−divergence
条件可得:
Df(P||Q)=∫xq(x) f(p(x)q(x))dx=∫xq(x)(maxt⫅dom(f∗){p(x)q(x)t−f∗(t)})dx
这里可以假设存在某一个函数
D
,其输入为
x
,输出为
t
,则有:
注意:不论函数
D
为何函数,其符号都为大于等于。
那么我们可以选择到某个函数
D
,使其上式右边取最大,则可得如下:
Df(P||Q)≈maxD∫xP(x)D(x)dx−∫xq(x)f∗(D(x))dx
Df(P||Q)
表示
f−divergence
,上面我们已经说明了
f−divergence
可以用来衡量两种分布的拟合程度。
继续推导可得:
得出的形式是不是很像上一博文中介绍的
V(G,D)
函数?
V = Ex∼Pdata[logD(x)] + Ex∼PG[log(1−D(x))]
继续可得:
所以我们可以这样理解更新
G
的过程,实际就是不断的减小
f−divergence
,而这个时候
f−divergence
直接就是用来衡量两种分布的拟合程度。
实际train的不同
在
original GAN
中,在
Inner loop
中通过多次循环来更新
D
,然后再更新
G
;而在上面介绍的
f−Gan
中,只需要一步即可更新
D、G
。
由此我们可以选中任意一种
f−diveragence
去
minize
。