【GANs学习笔记】(一)初步了解GANs

第一章 初步了解GANs

1. 生成模型与判别模型

      理解对抗网络,首先要了解生成模型和判别模型。判别模型比较好理解,就像分类一样,有一个判别界限,通过这个判别界限去区分样本。从概率角度分析就是获得样本x属于类别y的概率,是一个条件概率Py|x)。而生成模型是需要在整个条件内去产生数据的分布,就像高斯分布一样,需要去拟合整个分布,从概率角度分析就是样本x在整个分布中的产生的概率,即联合概率Pxy)。具体可以参考博文:

      http://blog.csdn.net/zouxy09/article/details/8195017

 

2. 对抗网络思想

       理解了生成模型和判别模型后,再来理解对抗网络就很直接了,对抗网络只是提出了一种网络结构,总体来说, GANs简单的想法就是用两个模型,一个生成模型,一个判别模型。判别模型用于判断一个给定的图片是不是真实的图片(从数据集里获取的图片),生成模型的任务是去创造一个看起来像真的图片一样的图片。而在开始的时候这两个模型都是没有经过训练的,这两个模型一起对抗训练,生成模型产生一张图片去欺骗判别模型,然后判别模型去判断这张图片是真是假,最终在这两个模型训练的过程中,两个模型的能力越来越强,最终达到稳态。(本书仅介绍GANs在计算机视觉方面的应用,但是GANs的用途很广,不单单是图像,其他方面,譬如文本、语音,或者任何只要含有规律的数据合成,都能用GANs实现。)

 

3. 详细实现过程

      假设我们现在的数据集是手写体数字的数据集minst,生成模型的输入可以是二维高斯模型中一个随机的向量,生成模型的输出是一张伪造的fake image,同时通过索引获取数据集中的真实手写数字图片real image,然后将fake imagereal image一同传给判别模型,由判别模型给出real还是fake的判别结果。于是,一个简单的GANs模型就搭建好了。

      值得注意的是,生成模型G和判别模型D可以是各种各样的神经网络,对抗网络的生成模型和判别模型没有任何限制。

3.1 前向传播阶段

1. 模型输入

      1我们随机产生一个随机向量作为生成模型的数据,然后经过生成模型后产生一个新的向量,作为Fake Image,记作D(z)

      2从数据集中随机选择一张图片,将图片转化成向量,作为Real Image,记作x

2. 模型输出

      将由1或者2产生的输出,作为判别网络的输入,经过判别网络后输出值为一个01之间的数,用于表示输入图片为Real Image的概率,real1fake0

      使用得到的概率值计算损失函数,解释损失函数之前,我们先解释下判别模型的输入。根据输入的图片类型是Fake ImageReal Image将判别模型的输入数据的label标记为0或者1。即判别模型的输入类型为(xfake,0)或者(xreal,1)

3.2 反向传播阶段

1. 优化目标

      原文给了这么一个优化函数:

http://www.gwylab.com/images/GANs/math1.png

      我们来理解一下这个目标公式,先优化D,再优化G,拆解之后即为如下两步:

      第一步:优化D

http://www.gwylab.com/images/GANs/math2.png

      优化D,即优化判别网络时,没有生成网络什么事,后面的G(z)就相当于已经得到的假样本。优化D的公式的第一项,使得真样本x输入的时候,得到的结果越大越好,因为真样本的预测结果越接近1越好;对于假样本G(z),需要优化的是其结果越小越好,也就是D(G(z))越小越好,因为它的标签为0。但是第一项越大,第二项越小,就矛盾了,所以把第二项改为1-D(G(z)),这样就是越大越好。

      第二步:优化G

http://www.gwylab.com/images/GANs/math3.png

      在优化G的时候,这个时候没有真样本什么事,所以把第一项直接去掉,这时候只有假样本,但是这个时候希望假样本的标签是1,所以是D(G(z))越大越好,但是为了统一成1-D(G(z))的形式,那么只能是最小化1-D(G(z)),本质上没有区别,只是为了形式的统一。之后这两个优化模型可以合并起来写,就变成最开始的最大最小目标函数了。

      我们依据上面的优化目标函数,便能得到如下模型最终的损失函数。

2. 判别模型的损失函数

      

      当输入的是从数据集中取出的real Iamge 数据时,我们只需要考虑第二部分,D(x)为判别模型的输出,表示输入xreal 数据的概率,我们的目的是让判别模型的输出Dx)的输出尽量靠近1

      当输入的为fake数据时,我们只计算第一部分,Gz)是生成模型的输出,输出的是一张Fake Image。我们要做的是让D(G(z))的输出尽可能趋向于0。这样才能表示判别模型是有区分力的。

      相对判别模型来说,这个损失函数其实就是交叉熵损失函数。计算loss,进行梯度反传。这里的梯度反传可以使用任何一种梯度修正的方法。

      当更新完判别模型的参数后,我们再去更新生成模型的参数。

3. 生成模型的损失函数

      

      对于生成模型来说,我们要做的是让Gz)产生的数据尽可能的和数据集中的数据一样。就是所谓的同样的数据分布。那么我们要做的就是最小化生成模型的误差,即只将由Gz)产生的误差传给生成模型。

      但是针对判别模型的预测结果,要对梯度变化的方向进行改变。当判别模型认为Gz)输出为真实数据集的时候和认为输出为噪声数据的时候,梯度更新方向要进行改变。

      即最终的损失函数为:

      

      其中表示判别模型的预测类别,对预测概率取整,为0或者1.用于更改梯度方向,阈值可以自己设置,或者正常的话就是0.5

4. 反向传播

      我们已经得到了生成模型和判别模型的损失函数,这样分开看其实就是两个单独的模型,针对不同的模型可以按照自己的需要去是实现不同的误差修正,我们也可以选择最常用的BP做为误差修正算法,更新模型参数。

      其实说了这么多,生成对抗网络的生成模型和判别模型是没有任何限制,生成对抗网络提出的只是一种网络结构,我们可以使用任何的生成模型和判别模型去实现一个生成对抗网络。当得到损失函数后就安装单个模型的更新方法进行修正即可。

猜你喜欢

转载自blog.csdn.net/a312863063/article/details/83551569