深度学习-->Improved GAN-->f-GAN

上一篇博文中详细总结和推导了 GAN 网络的原理,但是如此的 GAN 网络有他的不足之处,本博文将详细说明其不足之处,以及解决和提高的办法。

original GAN 不足之处

简单回顾GAN网络原理

这里写图片描述

蓝色的线表示: Generated distribution
绿色的线表示: Data(target) distribution
红色的线表示: Discriminator

在上图中的左上第一个子图中, generator 生成的数据分布与 Data distribution 相差较大,则 Discriminator 也即是 D(x) Generated distribution 以较低的概率,而给 Data(target) distribution 以较高的概率,由此得到 D(X) 的曲线走向。在左二子图中,更新后的 generator 可能会因为更新步伐太大,移到了 Data distribution 的右边,由此 D(X) 更新如图, GD 如此不断的更新迭代,最终 Generated distribution Data(target) distribution 重合,那么此时 D(X) 就变成了一条水平直线。

存在的问题

我们知道整个 GAN 网络的目标都是在:

通过不断的更新 DG 来得到比较好的 Generator ,也就是上式的 G ,那么在更新 D 时:

max V(G,D) = 2log2 + 2JSD(Pdata(X)||PG(X))

我们是不断的通过 Minize max V(G,D) 来更新 G ,那么问题来了,这个 max V(G,D) 是否能准确的反映 Pdata PG 之间的差距呢?

这里写图片描述

由上图可以看出,当 PG Pdata 无重合时(可能是sample出的样本没有重合),即使两者的 distribution 在改进,其 JS(PG||Pdata) 始终为 log2 ,那么在更新 G 参数时,没有改进的动力。很难得到很好的 Generator

Unified Framework

f-divergence

之前我们介绍的 GAN 网络中的 Discriminator 只是和 JensenShannon divergence 有关,论文 Training Generative Neural Samplers using Variational Divergence Minimization 中介绍了 fdivergence ,其 Discriminator 不只是仅仅由 JensenShannon divergence 来定义,其核心的一句话就是 you can use any fdivergence

我们假设有两个分布,分别 p q ,代入到 GAN 网络中,就是之前说的 Pdata PG ,其中 p(x) q(x) 就是 sample 出来的样本的概率。由此我们可以这样来定义 fdivergence

Df(P||Q)=xq(x) f(p(x)q(x))dx

显然这样定义的 Df(P||Q) 必须能起到衡量 PQ 分布的拟合程度,并且值越小拟合的越好。那么就必须具备以下条件:

  • f 函数必须是凸的
  • f(1)=0

那么可以得到,当对于所有的 x 都有 P(x)==Q(x) 时: Df(P||Q)=0 ,这个时候显然拟合的最好,并且是 smallest Df(P||Q)

再由凸函数的特性可得:

Df(P||Q)=xq(x) f(p(x)q(x))dxf(xq(x) p(x)q(x)dx)=fxp(x)=1

故可得到 Df(P||Q)1

其实 KL divergence 就可以理解为一种 fdivergence 。那么 f 可以选哪些函数呢?只要符合上面的要求即可:

这里写图片描述

Fenchel Conjugate

首先假设 f(x) 是一个凸函数,定义如下公式:

f(t)=maxxdom(f){xtf(x)}

得到 f(t) ,这里固定住不同的 x(x1,x2,..) ,都能得到不同的关于 t 的线性函数,其图可以如下:

这里写图片描述

然后取其 max ,就得到上图红色的那条线。由此可以得到一个结论:

f(x)f(t)

我们把这样的 f(t) 叫做 f(x) conjugate function

举个具体的例子,当 f(x)=xlogx 时,可得 f(t)

这里写图片描述

那么如何得出当 f(x)=xlogx 时, f(t) 的具体数学公式呢?

这里写图片描述

那么可得结论,当 f(x)=xlogx 时,其 conjugate function f(t)=exp(t1) ,也即:

f(x)=xlogxf(t)=exp(t1)

这里需要注意: (f)=f

Connect to GAN

那么上面讲的与 GAN 有什么关系呢?
假设 f(x)f(1)=0 ,则由上面的推导我们可以得出:

f(t)=maxxdom(f){xtf(x)}f(x)=maxtdom(f){txf(t)}

f(x)=maxtdom(f){txf(t)} 中,我们可以令 x=p(x)q(x) ,再由上面已经得出的 fdivergence 条件可得:

Df(P||Q)=xq(x) f(p(x)q(x))dx=xq(x)(maxtdom(f){p(x)q(x)tf(t)})dx

这里可以假设存在某一个函数 D ,其输入为 x ,输出为 t ,则有:

这里写图片描述

注意:不论函数 D 为何函数,其符号都为大于等于。

那么我们可以选择到某个函数 D ,使其上式右边取最大,则可得如下:

Df(P||Q)maxDxP(x)D(x)dxxq(x)f(D(x))dx

Df(P||Q) 表示 fdivergence ,上面我们已经说明了 fdivergence 可以用来衡量两种分布的拟合程度。

继续推导可得:

这里写图片描述

得出的形式是不是很像上一博文中介绍的 V(G,D) 函数?

V = ExPdata[logD(x)] + ExPG[log(1D(x))]

继续可得:

这里写图片描述

所以我们可以这样理解更新 G 的过程,实际就是不断的减小 fdivergence ,而这个时候 fdivergence 直接就是用来衡量两种分布的拟合程度。

实际train的不同

这里写图片描述

original GAN 中,在 Inner loop 中通过多次循环来更新 D ,然后再更新 G ;而在上面介绍的 fGan 中,只需要一步即可更新 DG

由此我们可以选中任意一种 fdiveragence minize

猜你喜欢

转载自blog.csdn.net/mr_tyting/article/details/79342828
今日推荐