★ Ensembling Off-the-shelf Models for GAN Training(GAN模型迎来预训练时代,仅需1%的训练样本)

Ensembling Off-the-shelf Models for GAN Training

集成现成的GAN训练模型

论文链接:https://arxiv.org/abs/2112.09130

项目链接:https://github.com/nupurkmr9/vision-aided-gan

视频链接:https://www.youtube.com/watch?v=oHdyJNdQ9E4


一、解决了什么问题?

每次GAN模型都要从头训练的日子过去了!最近CMU联手Adobe提出了一种新的模型集成策略,让GAN模型也能用上预训练,成功解决「判别器过拟合(训练集性能很强,但在验证集上表现得很差)」这个老大难问题。最近来自CMU和Adobe的研究人员在CVPR 2022发表了一篇文章,通过「选拔」的方式将预训练模型与GAN模型的训练相结合。

1、解决过拟合

在数据量十分有限的情况下,直接上大规模预训练模型作为判别器,非常容易导致生成器被「无情碾压」,然后就「过拟合」了。通过在FFHQ 1k数据集上的实验来看,即使采用最新的可微分数据增强方法,判别器仍然会过拟合(训练集性能很强,但在验证集上表现得很差)。

Training and validation accuracy w.r.t. training iterations for our DINO [11] based discriminator vs. baseline StyleGAN2-ADA discriminator on FFHQ 1k dataset.

 Figure 3. Our discriminator based on pretrained features has higher accuracy on validation real images and thus shows better generalization. In the above training, vision aided adversarial loss is added at the 2M iteration.

2、用1%的数据量,却达到一样的FID

在LSUN CAT 和LSUN CHURCH数据集上进行测试“StyleGAN2-ADA、 DiffAugment”和我们提出的模型,发现在小数据集上,我们提出的模型表现会远远超过当前最好的模型(StyleGAN2-ADA、 DiffAugment”),并且在LSUN CAT数据集上,用了百分之一的数据量,却达到了仅仅 0.7%的性能差异,所以提出的模型表现太好了!

 Figure 2. Performance on LSUN CAT and LSUN CHURCH. We compare with the leading methods StyleGAN2-ADA [41] and DiffAugment [109] on different sizes of training samples and full-dataset. Our method outperforms them by a large margin, especially in limited sample setting. For LSUN CAT we achieve similar FID as StyleGAN2 [44] trained on full-dataset using only 0.7% of the dataset.

二、怎么实现?

方法很简单,就是有一堆已经训好的vision模型,我把他们一个个地freeze backbone之后接一个linear prob来来判断图像是生成的还是真实的,然后每次挑选分得最好的topK个模型并与原始的鉴别器D拼接在一起用来做Discriminator。

Vision-aided GAN training.

 Figure 1. The model bank F consists of widely used and state-of-the-art pretrained networks. We automatically select a subset \{\hat{F}\}_{k=1}^{K} from F, which can best distinguish between real and fake distribution.
Our training procedure consists of creating an ensemble of the original discriminator D and discriminators \hat{D}_{k}=\hat{C}_{k} \circ \hat{F}_{k} based on the feature space of selected off-the-shelf models.
\hat{C}_{k} is a shallow trainable network over the frozen pretrained features.

这种方法有两个好处:

1、在预训练的特征上训练一个浅层分类器\hat{C}_{k})是使深度网络适应小规模数据集的常见方法,同时可以减少过拟合。

也就是说只要把预训练模型的参数固定住,再在顶层加入轻量级的分类网络就可以提供稳定的训练过程。

比如上面实验中的Ours曲线,可以看到验证集的准确率相比StyleGAN2-ADA要提升不少。

2、最近也有一些研究证明了,深度网络可以捕获有意义的视觉概念,从低级别的视觉线索(边缘和纹理)到高级别的概念(物体和物体部分)都能捕获。

建立在这些特征上的判别器可能更符合人类的感知能力。

并且将多个预训练模型组合在一起后,可以促进生成器在不同的、互补的特征空间中匹配真实的分布。

为了选择效果最好的预训练网络,研究人员首先搜集了多个sota模型组成一个「模型bank」,包括用于分类的VGG-16,用于检测和分割的Swin-T等。

 然后基于特征空间中真实和虚假图像的线性分割,提出一个自动的模型搜索策略,并使用标签平滑可微分的增强技术来进一步稳定模型训练,减少过拟合。具体来说,就是将真实训练样本和生成的图像的并集分成训练集和验证集。

对于每个预训练的模型,训练一个逻辑线性判别器来分类样本是来自真实样本还是生成的,并在验证分割上使用「负二元交叉熵损失」测量分布差距,并返回误差最小的模型。

一个较低的验证误差与更高的线性探测精度相关,表明这些特征对于区分真实样本和生成的样本是有用的,使用这些特征可以为生成器提供更有用的反馈。

研究人员我们用FFHQ和LSUN CAT数据集的1000个训练样本对GAN训练进行了经验验证。

Figure 4. Model selection using linear probing of pretrained features. We show correlation of FID with the accuracy of a logistic linear model trained for real vs fake classification over the features of off-the-shelf models. Top dotted line is the FID of StyleGAN2- ADA generator used in model selection and from which we finetune with our proposed vision-aided adversarial loss. Similar analysis for LSUN CAT is shown in Figure 12 in the appendix.

 结果显示,用预训练模型训练的GAN具有更高的线性探测精度,一般来说,可以实现更好的FID指标。

为了纳入多个现成模型的反馈,文中还探索了两种模型选择和集成策略

1)K-fixed模型选择策略,在训练开始时选择K个最好的现成模型并训练直到收敛;

2)K-progressive模型选择策略,在固定的迭代次数后迭代选择并添加性能最佳且未使用的模型。

实验结果可以发现,与K-fixed策略相比,progressive的方式具有更低的计算复杂度,也有助于选择预训练的模型,从而捕捉到数据分布的不同。例如,通过progressive策略选择的前两个模型通常是一对自监督和监督模型。文章中的实验主要以progressive为主。

最终的训练算法首先训练一个具有标准对抗性损失的GAN。

给定一个基线生成器,可以使用线性探测搜索到最好的预训练模型,并在训练中引入损失目标函数。

在K-progressive策略中,在训练了与可用的真实训练样本数量成比例的固定迭代次数后,把一个新的视觉辅助判别器被添加到前一阶段具有最佳训练集FID的快照中。

在训练过程中,通过水平翻转进行数据增强,并使用可微分的增强技术和单侧标签平滑作为正则化项。

还可以观察到,只使用现成的模型作为判别器会导致散度(divergence),而原始判别器和预训练模型的组合则可以改善这一情况

最终实验展示了在FFHQ、LSUN CAT和LSUN CHURCH数据集的训练样本从1k到10k变化时的结果。

Table 1. FFHQ and LSUN results with varying training samples from 1k to 10k. FID↓ is measured with complete dataset as reference distribution. We select the best snapshot according to training set FID, and report mean of 3 FID evaluations. In Ours (w/ ADA) we finetune the StyleGAN2-ADA model, and in Ours (w/ DiffAugment) we finetune the model trained with DiffAgument while using the corresponding policy for augmentation. Our method works with both ADA and DiffAugment strategy for augmenting images input to the discriminators.

在所有设置中,FID都能获得显著提升,证明了该方法在有限数据场景中的有效性。

提高最差样本的质量

为了定性分析该方法和StyleGAN2-ADA之间的差异,根据两个方法生成的样本质量来看,文中提出的新方法能够提高最差样本的质量,特别是对于FFHQ和LSUN CAT

Figure 8. Qualitative comparison of our method with StyleGAN2-ADA on AFHQ.
Left: randomly generated samples for both methods.
Right: For both our model and StyleGAN2-ADA, we independently generate 5k samples and find the worst-case samples compared to real image distribution. We first fit a Gaussian model using the Inception [86] feature space of real images. We then calculate the log-likelihood of each sample given this Gaussian prior and show the images with minimum log-likelihood (maximum Mahalanobis distance). 

当我们逐步增加判别器时,可以看到线性探测对预训练模型的特征的准确性在逐渐下降,也就是说生成器更强了。

 Figure 6. Linear probe accuracy of off-the-shelf models during our K-progressive ensemble training on FFHQ 1k. For the StyleGAN2-ADA, ViT (DINO) model has the highest accuracy and is selected first, then ViT (CLIP) and then Swin-T (MoBY). As we train with vision-aided discriminators, linear probe accuracy decreases for most of the pretrained models.

        

总的来说,在只有1万个训练样本的情况下,该方法在LSUN CAT上的FID与在160万张图像上训练的StyleGAN2性能差不多。

 Table 6. Additional ablation studies evaluated on FID↓ metric. Having two discriminators during training (frozen with random weights or trainable) or standard adversarial training for more iterations leads to only marginal benefits in FID. Thus the improvement is through an ensemble of original and vision-aided discriminators. ✗ means FID increased to twice the baseline, and therefore, we stop the training run.

而在完整的数据集上,该方法在LSUN的猫、教堂和马的类别上提高了1.5到2倍的FID。

CMU联手Adobe:GAN模型迎来预训练时代,仅需1%的训练样本-51CTO.COM

猜你喜欢

转载自blog.csdn.net/weixin_43135178/article/details/127134648