【NeurIPS 2023】Toward Understanding Generative Data Augmentation

Toward Understanding Generative Data Augmentation, NeurIPS 2023

论文:https://arxiv.org/abs/2305.17476

代码:https://github.com/ML-GSAI/Understanding-GDA

解读转载:[NeurIPS 2023] Toward Understanding Generative Data Augmentation - 知乎 (zhihu.com)

概述

生成式数据扩增通过条件生成模型生成新样本来扩展数据集,从而提高各种学习任务的分类性能。然而,很少有人从理论上研究生成数据增强的效果。为了填补这一空白,论文在这种非独立同分布环境下构建了基于稳定性的通用泛化误差界。基于通用的泛化界,论文进一步了探究了高斯混合模型和生成对抗网络的学习情况。在这两种情况下,论文证明了,虽然生成式数据增强并不能享受更快的学习率,但当训练集较小时,它可以在一个常数的水平上提高学习保证,这在发生过拟合时是非常重要的。最后,高斯混合模型的仿真结果和生成式对抗网络的实验结果都支持论文的理论结论。

主要结论

符号定义 

生成式数据增强

一般情况

我们可以对于任意的生成器和一致\beta _m稳定的分类器,推得如下的泛化误差:

一般来说,我们比较关心泛化误差界关于样本数m_S的收敛率。将m_G看成超参数,并将后面两项记为generalization error w.r.t. mixed distribution,可以定义“最有效的增强数量”:

在这个设置下,并和没有数据增强的情况进行对比(m_G=0),我们可以得到如下的充分条件,它刻画了生成式数据增强何时(不)能够促进下游分类任务,这和生成模型学习分的能力息息相关:

高斯混合模型

为了验证我们理论的正确性,我们先考虑了一个简单的高斯混合模型的setting。

我们在高斯混合模型的场景下具体计算Theorem 3.1中的各个项,可以推得

  1. 当数据量m_S足够时,即使我们采用“最有效的增强数量”,生成式数据增强也难以提高下游任务的分类性能。
  2. 当数据量m_S较小的,此时主导泛化误差的是维度等其他项,此时进行生成式数据增强可以常数级降低泛化误差,这意味着在过拟合的场景下,生成式数据增强是很有必要的。

生成对抗网络

我们也考虑了深度学习的情况。我们假设生成模型为MLP生成对抗网络,分类器为L层MLP或者CNN。损失函数为二元交叉熵,优化算法为SGD。我们假设损失函数平滑,并且第l层的神经网络参数可以被‖W_l‖控制。我们可以推得如下的结论:

  1. 当数据量m_S足够时,生成式数据增强也难以提高下游任务的分类性能,甚至会恶化。
  2. 当数据量m_S较小的,此时主导泛化误差的是维度等其他项,此时进行生成式数据增强可以常数级降低泛化误差,同样地,这意味着在过拟合的场景下,生成式数据增强是很有必要的。

实验

高斯混合模型模拟实验

我们在混合高斯分布上验证我们的理论,我们调整数据量m_S,数据维度d以及m_G=\gamma m_S。实验结果如下图所示:

  1. 观察图(a),我们可以发现当m_S相对于d足够大的时候,生成式数据增强的引入并不能明显改变泛化误差。
  2. 观察图(d),我们可以发现当m_S固定时,真实的泛化误差确实是O(d)阶的,且随着增强数量\gamma的增大,泛化误差呈现常数级的降低。
  3. 另外4张图,我们选取了两种情况,验证了我们的bound能在趋势上一定程度上预测泛化误差。

深度生成模型实验

我们使用ResNet作为分类器,cDCGAN、StyleGANv2-ADA和EDM作为深度生成模型,在CIFAR-10数据集上进行了实验。实验结果如下所示。由于训练集上训练误差都接近0,所以测试集上的错误率是泛化误差的一个比较好的估计。我们利用是否做额外的数据增强(翻转等)来近似m_S是否充足。

Method Repository
CDCGAN GitHub - znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN: Pytorch implementation of conditional Generative Adversarial Networks (cGAN) and conditional Deep Convolutional Generative Adversarial Networks (cDCGAN) for MNIST dataset
StyleGAN GitHub - NVlabs/stylegan2-ada-pytorch: StyleGAN2-ADA - Official PyTorch implementation
EDM data & training GitHub - wzekai99/DM-Improves-AT: Code for the paper "Better Diffusion Models Further Improve Adversarial Training" (ICML 2023)
bGMM GitHub - ML-GSAI/Revisiting-Dis-vs-Gen-Classifiers: Official implementation for "Revisiting Discriminative vs. Generative Classifiers: Theory and Implications".

猜你喜欢

转载自blog.csdn.net/m0_61899108/article/details/133521671
今日推荐