Differentiable Augmentation for Data-Efficient GAN Training

Differentiable Augmentation for Data-Efficient GAN Training

在训练数据量有限的情况下,生成式对抗网络(GANs)的性能会严重恶化。这主要是因为判别器正在记忆准确的训练集。为了解决这个问题,我们提出了可微分增强(DiffAugment),这是一个简单的方法,通过对真实和虚假样本施加各种类型的可微分增强来提高GANs的数据效率。以前直接augment训练数据的尝试操纵真实图像的分布,收效甚微;DiffAugment使我们对生成的样本采用可微增广,有效地稳定了训练,并使其收敛得更好

实验表明,我们的方法比各种GAN架构和损失函数在无条件和有条件的生成方面都有一致的收益。通过DiffAugment,我们在ImageNet 128×128上实现了最先进的FID为6.80,IS为100.8,在FFHQ和LSUN上,给定1000张图像的FID降低了224倍。此外,只需20%的训练数据,我们就能在CIFAR-10和CIFAR-100上达到顶级性能。最后,我们的方法只需使用100张图像就能生成高保真的图像,无需预训练,同时与现有的迁移学习算法相当。

Code is available at https://github.com/mit-han-lab/data-efficient-gans.

1 Introduction

大数据使深度学习算法取得了快速进展。特别是,最先进的生成对抗网络[11]能够生成不同类别的高保真自然图像[2,18]。许多计算机视觉和图形应用已经成为可能[32,43,54]。然而,这种成功是以大量的计算和数据为代价的。最近,研究人员提出了一些有前景的技术来提高模型推断的计算效率[22,36],而数据效率仍然是一个根本性的挑战。

GANs在很大程度上依赖于大量不同的、高质量的训练实例。仅举几例,FFHQ数据集[17]包含70,000张经过选择的高分辨率人脸图像;ImageNet数据集[6]对超过一百万张图像进行了各种物体类别的注释。收集这样大规模的数据集需要几个月甚至几年的大量人力,以及令人望而却步的注释成本。在某些情况下,甚至不可能有这么多的例子,例如,稀有物种的图像或特定人物或地标的照片。因此,消除巨大的数据集对GAN训练的需求是至关重要的。

在这里插入图片描述

图1:在数据量有限的情况下,BigGAN严重恶化。左边。在10%的CIFAR10数据中,FID在训练开始后不久就增加了,然后模型就崩溃了(红色曲线)。中间:判别器D的训练精度很快就饱和了。右边:D的验证精度急剧下降,表明D已经记住了准确的训练集,无法泛化。

然而,减少训练数据量会导致性能的急剧下降。例如,在图1中,只给出10%或20%的CIFAR-10数据,鉴别器的训练准确率很快饱和(接近100%);然而,它的验证准确率一直在下降(低于30%),这表明鉴别器只是简单地记忆了整个训练集。这种严重的过拟合问题破坏了训练动力学,导致图像质量下降。在图像分类中减少过拟合的一个广泛使用的策略是数据增强[20, 38, 42],它可以增加训练数据的多样性,而无需收集新的样本。诸如裁剪、翻转、缩放、颜色抖动[20]和区域遮蔽(Cutout)[8]等变换是常用的视觉模型增强方法。

然而,将数据增强应用于GANs是根本性的不同。如果转换只是添加到真实的图像上,生成器将被鼓励去匹配增强的图像的分布。因此,输出会受到分布偏移和引入的人工增强的影响(例如,一个区域被掩盖,不自然的颜色,见图5a)。另外,我们可以在训练判别器时同时增强真实图像和生成的图像;然而,这将打破生成器和判别器之间的微妙平衡,导致收敛性差,因为它们的目标完全不同(见图5b)。

在这里插入图片描述

图5:理解为什么vanilla 增强策略会失败。(a) "仅增强真实数据 "模仿了与增强所引入的相同的数据失真,例如平移填充、Cutout方块和颜色伪影;(b) "仅增强D "由于不平衡的优化而出现分歧–D对增强后的图像(包括T(x)和T(G(z))进行完美分类,但几乎无法识别G(z)(即没有增强的假图像),而G从这些图像中获得梯度。

为了解决这个问题,我们引入了一种简单而有效的方法–DiffAugment,它将相同的可微调增强应用于真实和虚假图像的生成器和判别器训练。它使梯度通过增强传播到生成器,在不操纵目标分布的情况下使判别器正规化,并保持训练动态的平衡。

在各种GAN架构和数据集上的实验不断证明了我们方法的有效性。通过DiffAugment,我们改进了BigGAN,在ImageNet 128×128上实现了6.80的Frechet ´Inception Distance(FID)和100.8的Inception Score(IS),并且在FFHQ和LSUN数据集上给定1000张图像,将StyleGAN2基线的FID降低了2-4倍。我们还在CIFAR-10和CIFAR-100上只用了20%的训练数据就达到了顶级性能(见图2和图10)。此外,我们的方法仅用100个例子就能生成高质量的图像(见图3)。在没有任何预训练的情况下,我们实现了与现有的转移学习算法相竞争的性能,这些算法过去需要数万张训练图像

2 Related Work

Regularization for GANs GAN训练通常需要额外的正规化,因为它们高度不稳定。为了稳定训练动力学,研究人员提出了几种技术,包括instance noise[39]、Jensen-Shannon正则化[34]、梯度惩罚[12,27]、光谱归一化[28]、对抗防御正则化[53]和一致性正则化[50]。所有这些正则化技术都隐含地或明确地惩罚了判别器在输入的局部区域内输出的突然变化。在本文中,我们提供了一个不同的视角,即数据增强,我们鼓励判别器在不同类型的增强下表现良好。在第4节中,我们表明我们的方法在实践中与正则化技术是互补的。

Data Augmentation. 许多深度学习模型采用保留标签的变换来减少过拟合:例如颜色抖动[20]、region masking[8]、翻转、旋转、cropping[20,42]、数据混合[47]、local and affine distortion[38]。最近,AutoML[40,55]被用于探索针对给定数据集和任务的自适应增强策略[4,5,23]。然而,将数据增强应用于生成模型(如gan)仍然是一个悬而未决的问题。与分类器训练(其中标签对输入的转换是不变的)不同,生成模型的目标是学习数据分布本身。直接应用增广将不可避免地改变分布。我们提出了一个简单的策略来规避上述问题。在我们工作的同时,有几种方法[16,41,52]独立提出了训练GANs的数据增强方法。我们建议读者查阅他们的工作以了解更多细节。

3 Method

生成对抗网络(GANs)的目的是通过生成器G和鉴别器d对目标数据集的分布进行建模。生成器G将输入潜在向量z(通常来自高斯分布)映射到其输出G(z)。鉴别器D学习将生成的样本G(z)与真实观测值x区分开来。标准GANs训练算法在给定损耗函数fD和fG的情况下,交替优化鉴别器的损失LD和生成器的损失LG:

在这里插入图片描述

这里可以使用不同的损失函数,如非饱和损失[11],其中 f D ( x ) = f G ( x ) = l o g ( 1 + e x ) f_D(x)=f_G(x)=log (1 + e^x) fD(x)=fG(x)=log(1+ex),以及hinge损失[28],其中 f D ( x ) = m a x ( 0 , 1 + x ) , f G ( x ) = x f_D(x)=max(0, 1 + x),f_G(x)= x fD(x)=max(0,1+x)fG(x)=x

尽管在更好的GAN架构和损失函数方面做了大量的工作,但仍然存在一个基本的挑战:随着训练的进行,判别器倾向于记忆观察结果。过度拟合的判别器会惩罚除准确的训练数据点以外的任何生成的样本,由于泛化能力差而提供无信息的梯度,并且通常会导致训练的不稳定性

Challenge: Discriminator Overfitting. 本文分析了BigGAN[2]在CIFAR-10上不同数据量的性能。如图1所示,即使给定100%的数据,鉴别器的训练和验证精度之间的差距仍在不断增加,这表明鉴别器只是简单地记忆训练图像。正如Brock等人[2]所观察到的,这种情况不仅发生在有限的数据上,也发生在大规模的ImageNet数据集上。BigGAN已经采用了光谱归一化[28],这是一种广泛应用于生成器和鉴别器体系结构的正则化技术,但仍然存在严重的过拟合问题

3.1 Revisiting Data Augmentation

在许多识别任务中,数据增强是一种常用的减少过拟合的策略——它具有不可替代的作用,也可以与其他正则化技术结合使用:例如权重衰减。我们已经证明了该鉴别器会遇到与二值分类器类似的过拟合问题。然而,相比于对判别器的明确正则化,GAN文献中很少使用数据增广[12,27,28]。事实上,[50]最近的一项研究发现,直接将数据增强应用于gan并不会改善基线。所以,我们想问的问题是:是什么阻止了我们简单地将数据增强应用到gan上?为什么增强gan不如增强分类器有效?

Augment reals only. 增广GANs最直接的方法是直接将增广T应用到真实观测值x上,我们称之为“仅增广真实数据”:

在这里插入图片描述

然而,"仅增强真实数据"偏离了生成式建模的初衷,因为模型现在学习的是T(x)而不是x的不同数据分布。这使得我们无法应用任何明显改变真实图像分布的增强。满足这一要求的选择,尽管强烈依赖于特定的数据集,但在大多数情况下只能是水平翻转。我们发现,应用随机水平翻转确实适度地提高了性能,我们在所有实验中都使用了它,以增强我们的基线。

我们在表1和图5a中分别从数量上和质量上证明了强制执行更强增强的副作用。正如预期的那样,模型学会了产生不需要的颜色和几何失真(例如,不自然的颜色,cutout holes),由这些增强引入,导致性能明显下降(见表1的 “Augment reals only”)。

在这里插入图片描述

图4:DiffAugment用于更新D(左)和G(右)的概述。DiffAugment对真实样本x和生成的输出G(z)均应用增广T。当我们更新G时,梯度需要通过T反向传播,这要求T相对于输入是可微的。

在这里插入图片描述

表1: 在CIFAR-10上用100%的训练数据进行DiffAugment与vanilla增强策略的比较。"Augment reals only "仅对(i)进行增强(见图4),对应公式(3)-(4);"仅增强D "对实数(i)和假数(ii)都进行增强,但不对G(iii)进行增强,对应公式(5)-(6);"DiffAugment "对实数(i)、假数(ii)和G(iii)进行增强。(iii)要求T是可微的,因为梯度应该通过T反向传播到G。 DiffAugment对应于公式(7)-(8)。IS和FID是用10k个样本测量的;验证集是参考分布。我们为每种方法选择具有最佳FID的快照。结果是5次评估运行的平均值;所有标准偏差相对小于1%。

Augment D only. 以前,“Augment reals only”是对真实样本进行one-sided augmentation,因此只有生成的分布与操纵的真实分布相匹配才能实现收敛。从鉴别器的角度来看,当我们更新D时,增加真实和虚假样本可能很诱人:

在这里插入图片描述

在这里,相同的函数T同时应用于真实样本x和虚假样本G(z)。如果生成器成功地模拟了x的分布,那么T(G(z))和T(x)应该和G(z)和x一样,对鉴别器来说是不可区分的。然而,这种策略导致了更糟糕的结果(见表1中的 “仅增强D”)。图5b显示了应用翻译的 "仅增强D "的训练动态。

图5b显示了应用翻译的 "仅增强D "的训练动态。尽管D对增强的图像(T(G(z))和T(x))进行了完美的分类,准确率在90%以上,但它未能识别G(z),即没有增强的生成图像,准确率低于10%。因此,生成器通过G(z)完全欺骗了鉴别器,无法从鉴别器获得有用的信息。这表明,任何打破生成器G和鉴别器D之间微妙平衡的尝试都容易失败

3.2 Differentiable Augmentation for GANs

"Augment reals only "的失败促使我们同时增强实数和假数样本,而 "只增强D "的失败则警告我们,生成器不应该忽略增强的样本。因此,为了通过增强的样本向G传播梯度,增强T必须是可微的,如图4所描述。我们称之为可微分增强(DiffAugment)。

在这里插入图片描述

请注意,T被要求是相同的(随机)函数,但不一定是图4中说明的三个地方的相同随机种子。在本文中,我们用三种简单的变换方式及其组成来证明DiffAugment的有效性。平移(在图像大小的[-1/8, 1/8]范围内,用零填充),cutout[8](用图像大小一半的随机正方形进行遮挡),以及颜色(包括[-0.5, 0.5]范围内的随机亮度,[0.5, 1.5]范围内的对比度,以及[0, 2]范围内的饱和度)。

如表1所示,BigGAN可以使用简单的Translation策略来改进,并使用Cutout和Translation的组合来进一步提升;当Color被组合使用时,它对最强的策略也是稳健的。图6分析了较强的DiffAugment策略通常以较低的训练精度为代价保持较高的判别器验证精度,缓解过拟合问题,并最终实现更好的收敛。

在这里插入图片描述

图6:使用100%训练数据对CIFAR-10进行不同类型的DiffAugment分析。一个更强的DiffAugment可以显著减小鉴别器训练精度之间的差距(中间)和验证精度(右边),导致更好的收敛(左边)。

4 Experiments

4.1 ImageNet

我们在128×128分辨率的ImageNet数据集上采用了表现最好的模型BigGAN[2]。此外,我们用随机水平翻转来增强真实图像,产生了我们所知的BigGAN的最佳重新实现(FID:我们的7.6 vs. 原始论文[2]的8.7)。在所有的数据百分比设置中,我们使用了简单的Translation DiffAugment。在表2中,我们的方法取得了明显的收益,特别是在25%的数据设置下,基线模型经历了早期的崩溃,并在100%的数据可用时,推进了最先进的FID和IS

4.2 FFHQ and LSUN-Cat

在这里插入图片描述

表3:1k、5k、10k和30k训练样本的FFHQ和LSUN-Cat结果。在固定的Color + Translation + Cutout增强的情况下,我们的方法改善了StyleGAN2的基线,并与同时进行的工作ADA[16]相当。FID是用5万个生成的样本测量的;完整的训练集被用来作为参考分布。我们为每种方法选择了具有最佳FID的快照。结果是5次评估运行的平均值;所有的标准偏差都相对小于1%。

我们在FFHQ肖像数据集[17]和LSUN-Cat数据集[46]上以256×256的分辨率进一步实验StyleGAN2[18]。我们研究了不同的有限数据设置,有1k、5k、10k和30k的训练图像可用。我们对所有StyleGAN2基线应用了最强的Color + Translation + Cutout增强,而没有任何超参数变化。真实的图像也像StyleGAN2[18]中通常应用的那样,用随机水平翻转来增强。结果显示在表3中。在所有的数据百分比设置下,我们的性能提升是相当大的。此外,在DiffAugment中使用的固定策略下,我们的性能与ADA[16]相当,后者是一个基于自适应增强策略的并行工作

4.3 CIFAR-10 and CIFAR-100

我们在类别条件的BigGAN[2]和CR-BigGAN[50]以及无条件的StyleGAN2[18]模型上进行实验。为了进行公平的比较,我们还对所有基线的真实图像进行了随机水平翻转的增强。基线模型已经采用了先进的正则化技术,包括光谱归一化[28]、一致性正则化[50]和R1正则化[27];然而,在10%的数据设置下,它们都没有达到令人满意的结果

在这里插入图片描述

表4:CIFAR-10和CIFAR-100的结果。我们为每种方法选择具有最佳FID的快照。结果是5次评估运行的平均值;所有的标准偏差都相对小于1%。我们使用10k个样本和验证集作为FID计算的参考分布,就像之前的工作[50]那样。同时进行的工作[14, 16]使用不同的协议。50k样本和训练集作为参考分布。如果我们采用这种评估协议,我们的BigGAN + DiffAugment实现了4.61的FID,CR-BigGAN + DiffAugment实现了4.30的FID,而StyleGAN2 + DiffAugment实现了5.79的FID。

对于DiffAugment, BigGAN模型采用Translation + Cutout, StyleGAN2采用Color + Cutout, StyleGAN2采用Color + Translation + Cutout, StyleGAN2采用10%或20%的数据。如表4所示,我们的方法在不改变任何超参数的情况下,独立于基线架构、正则化和损失函数(BigGAN中的hinge损耗和StyleGAN2中的non-saturating loss)改进了所有基线。我们请读者参阅附录(表6-7),了解IS的完整表格。这些改进是相当可观的,特别是在可用数据有限的情况下。

4.4 Low-Shot Generation

对于某个人、某个物体或某个地标,如果不是完全不可能的话,收集一个大规模的数据集往往是很乏味的。为了解决这个问题,研究人员最近在图像生成的环境中利用了few-shot learning[9, 21]。Wang等人[45]利用微调来转移在外部大规模数据集上预训练的模型的知识。一些作品提出只对模型的一部分进行微调[30,31,44]。下面,我们表明,我们的方法不仅可以在不使用外部数据集或模型的情况下产生有竞争力的结果,而且与现有的迁移学习方法是正交的。

我们使用与Mo等人[30]相同的代码库在他们的数据集(有160只猫和389只狗的AnimalFace[37])上复制了最近的迁移学习算法[30, 31, 44, 45],基于FFHQ人脸数据集[17]中预训练的StyleGAN模型为了进一步证明数据的效率,我们收集了100张奥巴马、grumpy cat和熊猫的数据集,并在每个数据集上只使用100张图片进行StyleGAN2模型的训练,而没有进行预训练

在这里插入图片描述

表5:Low-shot生成结果。只有100张(奥巴马、不爽猫、熊猫)、160张(猫)或389张(狗)训练图像,我们的方法与迁移学习算法相当,后者预先训练了7万张图像。使用5k生成的样品进行FID测定;训练集是参考分布。我们选择每个方法的最佳FID快照。

对于DiffAugment,我们对StyleGAN2采用了Color + Translation + Cutout,对vanilla微调算法TransferGAN[45]和FreezeD[30]采用了Color + Cutout,该算法冻结了判别器的前几层。表5显示,DiffAugment在所有数据集上取得了独立于训练算法的一致收益。在没有任何预训练的情况下,我们仍然取得了与需要数万张图片的现有转移学习算法相当的结果,但在100张奥巴马数据集上是个例外,在那里用人脸进行预训练显然会导致更好的泛化。

见图3和附录(图18-22)中的定性比较。虽然人们可能会担心生成器可能会过度拟合微小的数据集(即生成相同的训练图像),但图7表明我们的方法通过风格空间的线性插值几乎没有过度拟合[17](也见附录中的图15);请参考附录(图16-17)中的最近邻测试。

4.5 Analysis

下面,我们将研究更小的模型或更强的正则化是否同样可以减少过拟合,以及DiffAugment是否仍然有帮助。最后,分析了DiffAugment的附加选择

Model Size Matters? 如图8a所示,当使用完整的模型时,基线在CIFAR-10上的训练数据为10%时严重过拟合,在1/4个通道时达到最小FID为29.02。然而,在所有的模型容量上,它都被我们的方法超越了。在1/4通道时,我们的模型实现了明显更好的FID,为21.57,而随着模型的增大,差距呈单调增长。我们请读者参阅附录(图11)中的IS图

在这里插入图片描述

图8:使用10%训练数据对CIFAR-10进行小模型或强正则化分析。(a)较小的模型减少了对BigGAN基线的过拟合,而我们的方法在所有模型容量上都占据了优势。(b)在基线StyleGAN2的R1正则化γ的宽扫描上,其最佳FID(26.87)仍然比我们的(14.50)差得多

Stronger Regularization Matters? 由于StyleGAN2采用**R1正则化[27]**来稳定训练,我们将其强度从γ=0.1提高到 1 0 4 10^4 104,并在图8b中绘制了FID曲线。虽然我们最初发现γ=0.1在100%的数据设置下效果最好,但在10%的数据设置下,选择γ= 1 0 3 10^3 103将其性能从34.05提高到26.87。当γ= 1 0 4 10^4 104时,在75万次迭代中,我们只观察到44万次迭代时的最小FID为29.14,此后的性能不断恶化。然而,它的最佳FID仍然比我们的差1.8倍(默认的γ=0.1)。这表明DiffAugment与明确的正则化判别器相比更加有效

在这里插入图片描述

图9:不同类型的DiffAugment始终强于基线。我们用10%的训练数据报告StyleGAN2在CIFAR-10上的FID。

Choice of DiffAugment Matters? 我们在图9中研究了DiffAugment的其他选择,包括随机90°旋转({-90°,0°,90°}各1/3的概率),高斯噪声(标准偏差为0。 1),以及涉及双线性插值的一般几何变换,如双线性平移(在[-0.25,0.25]内)、双线性缩放(在[0.75,1.25]内)、双线性旋转(在[-30°,30°]内)和双线性剪切(在[-0.25,0.25]内)。虽然所有这些政策的表现一直优于基线,但我们发现Color + Translation + Cutout DiffAugment特别有效。这种简单性也使得它更容易部署

猜你喜欢

转载自blog.csdn.net/weixin_37958272/article/details/119788412