BEGAN: Boundary Equilibrium Generative Adversarial Networks的理解

BEGAN: Boundary Equilibrium Generative Adversarial Networks的理解

这是一篇2017年5月上传到arXiv上的文章,作者是David Berthelot,来自Google。Boundary Equilibrium译作“边界均衡”,文章创新的地方主要有以下几个地方:

  • 应用auto-encoder实现Discriminator
  • Discriminator的Loss_D由输入原图(input_img)与Decoder恢复的输出图(recover_img)之间的逐点error构成
    L ( v ) = | v D ( v ) | ( 1 )

    因而将产生两个Loss_D,分别为真图判别损失Loss_D_real,以及伪图判别损失Loss_D_fake。
  • Loss_D可看成是随机的分布,由real_img所形成的Loss_D_Real分布与由Generator生成的假图(fake_img)所形成的Loss_D_Fake分布,出现了两个分布,用Wasserstein Distance(简称WD)来衡量这两个分布的距离。Discriminator的目标是尽量拉开这两个分布的距离,而Generator的目标是缩小这两个分布的距离——GAN的基本思想。
  • 引入了一个均衡的概念来调节Discriminator训练时的两个目标的比重:目标1,是提高auto-encoder的重构能力,即auto-encoder恢复输入input_img的能力;目标2,提高D的分辨真伪的能力。该均衡控制量是可以变动的,就像是电路中的反馈环,构成了反馈比例控制(Proportional Control)迭代机制。

本文是以WD的出发点来解释和构造GAN的,以下是Wasserstein Distance的定义:

W ( u 1 , u 2 ) = inf γ Γ ( u 1 , u 2 ) E ( x 1 , x 2 ) γ [ | x 1 x 2 | ] ( 2 )

WD本来就是用来衡量两个分布的距离的,知乎上有一篇文章讲得很详细: https://www.zhihu.com/question/39872326?sort=created
在BEGAN中, u 1 u 2 是两个分布, u 1 代表由real_img在Discriminator上生成的Loss_D,即Loss_D_real,而 u 2 代表fake_img在Discriminator上生成的Loss_D,即Loss_D_fake。 W ( u 1 , u 2 ) 便是衡量这两个分布的距离。
(2)式右边是求1次范数均值的下确界。 x 1 是服从 u 1 的随机样本,同理, x 2 是服从 u 2 的随机样本,它们的联合分布服从 γ ,此中有一个约束条件,即是联合分布服 γ 的边沿分布必须是 u 1 u 2 γ 的所有可能形式构成一个概率空间 Γ ( u 1 , u 2 ) ,因此 γ Γ 的一个元素。在 Γ ( u 1 , u 2 ) 中取最小值的那个联合分布 γ 是所求的目标分布,它的期望 E ( x 1 , x 2 ) γ [ | x 1 x 2 | ] 就是所求距离。
作为Discriminator希望此距离越大越好,但最优联合分布 γ 的形式是未知的,因而直接求十分困难,因而需要用可变下界来渐近之,通过Jensen不等式有:
inf E [ | x 1 x 2 | ] inf | E [ x 1 x 2 ] | = | m 1 m 2 | ( 3 )

其中 m 1 m 2 分别是 x 1 x 2 的均值。于是可得 W ( u 1 , u 2 ) 的下界,有:
W ( u 1 , u 2 ) | m 1 m 2 | ( 4 )

将(1)代入有:
{ m 1 = E v u 1 | v D ( v ) | m 2 = E G ( z G ) | G ( z G ) D ( G ( z G ) ) | ( 5 )

要尽量增加距离,只有两种情况:
{ m 1 m 2 0 ( a ) o r { m 1 0 m 2 ( b ) ( 6 )

选(b)较合理,因为当D训练好时,对于真图这边的Error,我们是希望误差越小越好的,即 m 1 趋向0。因而,(4)变形为
W ( u 1 , u 2 ) m 2 m 1 ( 7 )
。尽量提升下界,即求 m 2 m 1 的最大值,但常见的ML后向传递计算搜索的是Loss的最小值,因而在求Loss_D时需要对(7)求反,如下:
L D = m 1 m 2 = E ( L ( x ) ) E ( L ( G ( z D ) ) ) ( 8 )

当经过理想的训练过程后,D应该分辨不出真伪,即: W ( u 1 , u 2 ) 0 ,因而有:
E ( L ( x ) ) = E ( L ( G ( z D ) ) ) ( 9 )

但由于在训练过程中(9)式两边并不匹配,左边会比右边衰减得快,因为Generator的生成过程收敛速度较慢,因而在(8)式右端第二项中添加一个可变的系数对它进行调整,平衡减式两端数值,该系数就是所谓的均衡(Equilibrium)—— k t k t 的调整是一个迭代过程,如同电路的反馈环路, k t 的迭代关系如下:
k t + 1 = k t + λ k ( γ L ( x ) L ( G ( z G ) ) ) ( 10 ) γ = E ( L ( G ( z G ) ) ) E ( L ( x ) ) ( 11 ) L D = E ( L ( x ) ) k t E ( L ( G ( z D ) ) ) ( 12 ) k t  is clamped to [0,1]

(10)式中 λ k 是一个超级参数,取值可以是0.001。(8)式经过均衡的调整,变为了(12)式,这样的目的是让(12)式的前后两项不要相差太大,起到一个制衡(Trade off)的作用。
以下是用pytorch实现的一次训练迭代过程:

# ---------------------
#  Train Discriminator
# ---------------------

optimizer_D.zero_grad()

# Measure discriminator's ability to classify real from generated samples
d_real = discriminator(real_imgs)
d_fake = discriminator(gen_imgs.detach())

d_loss_real = torch.mean(torch.abs(d_real - real_imgs))
d_loss_fake = torch.mean(torch.abs(d_fake - gen_imgs.detach()))
d_loss = d_loss_real - k * d_loss_fake

d_loss.backward()
optimizer_D.step()

#----------------
# Update weights
#----------------

diff = torch.mean(gamma * d_loss_real - d_loss_fake)

# Update weight term for fake samples
k = k + lambda_k * diff.item()
k = min(max(k, 0), 1) # Constraint to interval [0, 1]

还有一篇BEGAN的翻译,可以以看看:https://blog.csdn.net/m0_37561765/article/details/77512692
本文的参考:
1、代码
2、文章


猜你喜欢

转载自blog.csdn.net/StreamRock/article/details/81023212