致敬GAN与我最喜欢的框架pytorch

小编从17年暑假开始进入实验室学习,自学了深度学习与机器学习,但理解并不深刻;18年暑假开始从一个师姐手中接下一个课题“线条简化”,但其实做的工作主要是数据集标注与跑实验,尽管最后稍稍改些代码并在论文里提供了几张图,最后的犒劳是“五作(仵作)”;好吧,果然人还是要强大起来,才能获得主动权;19年已经过去了一半,小编才在前不久才更深层次理解了GAN与基于pytorch的实现,复现了2018年传说中的CartoonGAN-pytorch,不才不才。

(一)GAN的基本原理

GAN之所以起作用是GAN中的生成器G与鉴别器D内部的相互对抗,G通过不断提高自己的生成能力(在图像转换任务中就是图像转换能力) ,将源域SRC的数据样本X映射为Y',试图瞒过D,让D误以为是Y'就是目标域TAR的样本Y。

在宏观上(整个数据域TAR或者SRC上),就是希望G将X所在的数据分布映射为Y'的数据分布,使得Y'与Y的数据分布式近似的。

(二)数学表达

生成器G一般被设定是一个Encoder-Decoder的模型,在CV领域,其作用往往是根据输入的图像(或噪声)输出一张图像;鉴别器的作用则就是一个二分类器,判别True或者False。

一般我们有:x\overset{G}{\rightarrow}z=G(x)\overset{D}{\rightarrow}True /False

在分类任务中,我们常用的loss是交叉熵损失;对于二分类,我们习惯使用二元交叉熵(BCE: Binary-entropy loss)。定义如下:

对G,希望合成的图片G(x)可以瞒过检测器D的检测,因此我们希望D(G(x))的响应越大(接近于1);即——

L_{G}=log(1-D(G(x)))

其中G(x)越像,D(G(x))越大,1-D(G(x))越小,log后的值也就越小(接近负无穷)。

对D,希望它火眼金睛,可以识别出哪些图像是真的来自TAR域,哪些是冒牌的,因此我们希望D(y)响应值越大(接近于1),而D(G(x))响应值越小(接近于0);即——

L_{D}=-log(D(G(y))-log(1-D(G(x))).

(三)结合pytorch脚本的GAN更新策略

一下的内容来自于我参考的两篇很好的博客,我觉得应该再尝试着复述一下我才更印象深刻。

1. 策略1:先更新判别器D,后更新生成器G。

"""Updating in a single iteration for GAN-training in pytorch"""
for epoch in range(EPOCH):
    for i in range(ITERATION):
        #### 定义用于计算对抗损失的两个目标(1和0)
        valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真标签,都是1
        fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device)  # 假标签,都是0

        #### Train D
        optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零

        ## Train D with real data (Y)
        real_imgs = imgs.to(device)
        pred_real = discriminator(real_imgs)               # 判别器对真数据的输出
        real_loss = adversarial_loss(pred_real, valid)     # 判别器对真实样本的损失
        
        ## Train with fake data (Y')
        z = torch.randn((imgs.shape[0], 100)).to(device)   # 噪声
        gen_imgs = generator(z)                            # 从噪声中生成假数据
        pred_gen = discriminator(gen_imgs)                 # 判别器对假数据的输出
        fake_loss = adversarial_loss(pred_gen, fake)       # 判别器对假样本的损失

        d_loss = (real_loss + fake_loss) / 2               # 两项损失相加取平均

        # 下面这行代码十分重要,将在正文着重讲解
        d_loss.backward(retain_graph=True)                 # 计算权重梯度;retain_graph 十分重要,显示声明保留计算图;否则计算图内存将会被释放
        optimizer_D.step()                                 # 判别器参数更新

        #### Train G
        g_loss = adversarial_loss(pred_gen, valid)         # 生成器的损失函数
        optimizer_G.zero_grad()                            # 生成器参数梯度归零
        g_loss.backward()                                  # 反向传播计算生成器的损失函数梯度
        optimizer_G.step()                                 # 生成器参数更新
    
        # end of the iteration
    # end of the epoch

讨论之前,我们明确:

  • 计算图被用来记录一个计算的过程,这个过程中,方形表示“运算”,原型表示“变量”,一个“变量”包括数据与权重。沿着计算图前向传播可以得到结果与各个中间变量;逆向传播可以计算各个节点的变量的梯度!其中,一个计算图有且仅有一次被用来反向传播BP。换句话说,就是一个计算图如果在反向传播后还存在,那就能防止在当前迭代中对这一部分计算做第二次反向传播求梯度后梯度更新。
  • 对于同一个计算图,在同一次迭代过程中由多组数据流经过,不同数据流计算的loss叠加后,其反向传播也仅是计算一次,相当于是梯度的累加!

训练鉴别器时——在计算D(y)时,计算图包括了D的整个前向过程;计算D(G(x))时,计算图包括了G和D的整个前向过程。由于d_loss包括了real_loss=criterion( D(y), True_label )和fake_loss=criterion( D(G(x)), False_label ),因此反向传播时,对G和D分别做了一次反向传播。但是,注意到我们后面只有optimizer_D做了梯度更新。我们知道pytorch中的优化器初始化的时候会“安排”它负责哪些module的参数更新,其他的模块它就不管了。因此,此次 optimizer_D.step() 仅更新了D的梯度一次!

可以看到,此次反向计算了D和G的梯度,但仅对D做梯度更新。由于G没有更新梯度,因此它的计算图部分被保留了下来。

训练生成器时——直接使用前面计算得到的D(G(x)),不需要重复计算,计算图包括了D的整个前向过程,在loss反向传播时,必然需要对D和G都做一次反向求梯度。这就是为什么我们要在loss_D.backward()中声明保留计算图,因为后面的 generator 算梯度时还要用到G的这部分计算图,所以用这个参数控制计算图不被释放。当然,注意到我们这里的脚本写的是:训练D时用的fake数据和训练G时用的fake数据是同一组数据;就是说,假如你训练D时用的fake数据和训练G时用的fake数据不同时(分别初始化后经过G生成),就不需要了哈,因为G的计算图会再次生成。

另一方面,我们看到在loss_G反向传播之前,要先声明optimizer_G.zero_grad(),将上一次loss_D反向时计算的梯度归零(不然会叠加在一起,而考虑到我们先训练D,就是因为G的更新依赖于D,所以先训练D时计算的G的梯度是没有任何根据的,价值不大)。这之后我们用optimizer_G.step()仅实现了G的梯度更新。

综上,在这个策略中,我们对D和G都做了两次反向传播(计算了两次梯度)——第一次传播为了更新D的参数,但不得不额外计算G的梯度;第二次传播是为了更新G的参数,但不得不额外计算D的梯度。

2. 策略2:先更新生成器G,后更新判别器D。

"""Updating in a single iteration for GAN-training in pytorch"""
for epoch in range(EPOCH):
    for i in range(ITERATION):
        #### 定义用于计算对抗损失的两个目标(1和0)
        valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真标签,都是1
        fake  = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假标签,都是0

        real_imgs = Variable(imgs.type(Tensor))                     # 真实数据 y

        #### 训练生成器
        optimizer_G.zero_grad()                                     # 生成器参数梯度归零
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 随机噪声
        gen_imgs = generator(z) # 根据噪声生成虚假样本
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)   # 用真实的标签+假样本,计算        生成器损失
        g_loss.backward()                                           # 生成器梯度反向传播,反向传播经过了判别器,故此时判别器参数也有梯度
        optimizer_G.step()                                          # 生成器参数更新,判别器参数虽然有梯度,但是这一步不能更新判别器

        #### 训练判别器
        optimizer_D.zero_grad()                                     # 把生成器损失函数梯度反向传播时,顺带计算的判别器参数梯度清空
        real_loss=adversarial_loss(discriminator(real_imgs), valid) # 真样本+真标签:判别器损失
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 假样本+假标签:判别器损失;注意这里的".detach()"的使用
        d_loss = (real_loss + fake_loss) / 2                        # 判别器总的损失函数
        d_loss.backward()                                           # 判别器损失回传
        optimizer_D.step()                                          # 判别器参数更新
        
        # end of the iteration
    # end of the epoch

分析发现:

在训练G时,计算gen_imgs = G(z)时生成了覆盖G前向过程的计算图,计算g_loss = criterion(G(z), True)时生成了覆盖D前向过程的计算图,g_loss反向传播时对D和G都计算了梯度;但是我们只使用optimizer_G更新G的梯度。完了后D和G的计算图被释放。

在训练D时,计算D(y)时生成了仅包括D的计算图,计算D(G(z))时则是在刚刚生成的D的计算图上又过了一遍。在分享传播时,real_loss反向传播计算了D的梯度1次;紧接着loss_fake想要反向传播,但是,当它往回走走到D的输入位置时,发现前方无路可走了,因为计算它的G的计算图被释放了,因此,我们需要显示告诉它梯度更新到此处即可,就是通过“G(z).detach()”实现的。detach 的意思是,这个数据和生成它的计算图“脱钩”了,即梯度传到它那个地方就停了,不再继续往前传播。

综上,在此策略中,我们对G做了一次反向传播,对D做了两次次反向传播。并且不需要专门在内存中保留G的计算图。

 

【总结】

策略1的好处是:noise只进行了一次前向传播,缺点是需要对D和G都做两次反向传播,还需要在内存中保留计算图(D+G)。

策略2的好处是:先更新G,使得更新后前向传播的计算图(D+G)可以被放心销毁,不用占用太多内存;后面更新D,显然需要再一次产生新的计算图,不过这次只包括D,相对策略1较小;同时这是对D作第2次前向传播,同理也就需要做第2次反向传播。

前者是多了一次对G反向传播求梯度;后者是多了一次对D的前向传播。如果D比较复杂,应该采取策略1;反之则应该采取策略2.而通常情况下,D是要比G简单得多的,故应该采取策略2居多。

(最后一句话来自知乎上的原文)但是第二种先更新generator,再更新 discriminator 总是给人感觉怪怪得,因为 generator 的更新需要 discriminator 提供准确的 loss 和 gradient,否则岂不是在瞎更新?

 

【参考】

1. Pytorch: detach 和 retain_graph,和 GAN的原理解析

2. Pytorch: detach 和 retain_graph

猜你喜欢

转载自blog.csdn.net/WinerChopin/article/details/95986694
今日推荐