Deep Generative Model2--GAN(Generative Adversarial Network)-cont2

Theory behind GAN

这篇文章同样follow李宏毅老师的机器学习课程关于GAN理论的讲解。也可以算是一份学习笔记。OK,我们开始。

GAN,作为一种生成式的模型,我们想有的是构造一个生成器Generator,使其背后的数据分布P_G(x) 和真实的数据分布P_data(x)尽可能的接近,这样就可以让机器生成的数据(比如image)和真实的图片数据尽可能地相似。

所以我们的目标就是让机器去学到这个P_G(x)。但是现在的一个问题就是,我们根本不知道真实的分布P_data(x)长什么样,只知道它是一个非常复杂的高维分布,那么该怎么做呢?直觉的想法就是我们从真实的分布中抽样,得到{x1,x2,…xm},然后使用MLE的思想,通过调整P_G(x)的参数,使得通过P_G(x)产生{x1,x2,…xm}的概率最大,这样可以让P_G(x)向真实分布尽可能靠拢。

这里我们再做进一步的思考,我们通过MLE找到的P_G(x)的参数,记做theta*,事实上就是使得P_G和P_data的KL Divergence最小的那个theta。这个结论非常符合我们的理解。下面是具体的推导
在这里插入图片描述
但是我们再来回顾一下之前提到的一个点,那就是往往在真实的情况中,我们的真实分布P_data(x)是非常复杂的,那么如果要让我们的P_G去尽可能接近它,那么也就意味着P_G也是非常复杂的,比如说G是一个多个hidden layer组成的deep neural network,那么我们就很难计算likelihood。

好,紧接着上面的铺垫,那么现在我们进入Generator的讨论。在GAN中,G是一个neural network。我们的input可以是从一个简单分布(比如高斯分布)中抽样得到的一系列随机变量,记做z,然后通过我们的G,生成的一系列x=G(z)形成的分布,就是我们的P_G(x)。接下来我们的目标就是让我们的P_G(x)和真实的数据分布P_data(x)越接近越好。
如下图
在这里插入图片描述
但是我们想要获得这个G*,就需要去计算Divergence,但是问题是,这个Divergence到底要如何计算呢?我们并不知道P_G和P_data。那么GAN是如何解决的呢?方法其实就是去抽样,通过从P_data和P_G中抽样并使用Discriminator来测量Divergence。

那我们来仔细地看一下。Discriminator,下文记做D,可以近似地看作一个分类器。它的目标就是尽可能地区分出真实的数据点和由G生成的伪数据。它的objective function如下
在这里插入图片描述
非常容易理解,目标最大化V,对于第一项由P_data生成的数据,给予尽可能大的概率使其标签为1,对于第二项由P_G生成的数据x ,要使其通过D输出的概率尽可能接近于0,从而使1-D(x)尽可能大。
在这里插入图片描述
事实上,D的目标函数与JS Divergence有着非常密切的联系。下面开始枯燥的数学证明…
根据上面给出的V的表达式,我们将其转化为积分的形式,如下
在这里插入图片描述
但是这里的一个问题是,如果我们真想获得Max V,这里的D(x)应该是可以取任意的函数,但是事实上,我们的D(x)是一个network,也就是说存在一定的限制。但是我们这里还是假设D可以取任意函数形式。
那么现在我们的任务就变成了在给定x的情况下找到D使得下面的式子最大
在这里插入图片描述
通过微分,我们最终得到对D
(x)为如下的表达式
在这里插入图片描述
现在我们把D*(x)的表达式代回到V(D,G)中,最终可以转化成JS Divergence
在这里插入图片描述
根据上面的推导,我们巧妙地把前面求解G中需要用到Divergence的表达式部分替换成maxV(G,D),也就是说,求解G变成了如下的形式
在这里插入图片描述
那么具体该怎么解这个最优化问题呢?我们用到的算法就是交替优化G和D。关于具体为什么这个最优化问题可以通过交替优化G和D来实现,大家可以去看一下李宏毅老师的视频,这里因为时间原因就不做具体说明了。OK,假设我们已经明白了原因,我们最后来仔细看一下交替优化算法的具体细节。
理想状态下,我们的V函数的形式如下
在这里插入图片描述
但是事实上,我们并不可能真的知道期望,因此这里我们还是通过抽样来近似。OK,接下来我们来看一下具体的算法流程,如下图
在这里插入图片描述
首先我们先初始化G和D。
在每个循环中,我们首先训练D,从P_data中抽样得到m个样本点{x1,x2,…xm}。从某个简单分布中抽样得到m个样本{z1,z2,…zm},再通过当前的G,映射得到{x1,x2,…x~m},D需要做的就是将这两类数据尽可能区分开,通过不断更新参数theta_d来优化目标函数。这个过程需要持续多次,来使D尽可能真正靠近D来使V被最大化。只有这样,我们才能真正用这部分来替代JS Divergence。
然后我们训练G,并更新theta_G,目标使得G生成的数据能够尽可能骗过D。但是这里需要注意的一个重点是G的更新不能太多,一般只更新一次,为什么需要这么做呢?原因在于G的更新会影响V(G,D),导致原先的D
不再maxV了,那么这样也就意味着不能再被当作JS Divergence了。因此我们必须使G的更新不会太大,使得V(G,D)的变化不会太大。关于这一点,大家也可以仔细看一下李宏毅老师视频的讲解,这里只是大概地描述一下。

参考:https://www.bilibili.com/video/BV1JE411g7XF?p=75

猜你喜欢

转载自blog.csdn.net/weixin_44607838/article/details/111711639