【NeurIPS 2023】生成データ拡張の理解に向けて

生成データ拡張の理解に向けて、NeurIPS 2023

論文: https://arxiv.org/abs/2305.17476

コード: https://github.com/ML-GSAI/Understanding-GDA

解釈と転載: [NeurIPS 2023] 生成データ拡張の理解に向けて - Zhihu (zhihu.com)

概要

生成データ拡張は、条件付き生成モデルを通じて新しいサンプルを生成することでデータセットを拡張し、それによってさまざまな学習タスクの分類パフォーマンスを向上させます。ただし、生成データ拡張の効果を理論的に調査した研究はほとんどありません。このギャップを埋めるために、この論文では、この非 IID 環境での安定性に基づいた一般化誤差限界を構築します。この論文では、一般化限界に基づいて、混合ガウス モデルと敵対的生成ネットワークの学習状況をさらに調査します。どちらの場合も、この論文は、生成データ拡張では学習速度は速くなりませんが、トレーニング セットが小さい場合、つまり過学習が発生した場合に一定レベルで学習保証を向上させることができることを実証しています。時間が非常に重要です。最後に、混合ガウス モデルのシミュレーション結果と敵対的生成ネットワークの実験結果は、この論文の理論的な結論を裏付けています。

主な結論

シンボルの定義 

生成的なデータの拡張

一般的に

任意のジェネレーターと一貫性のある\beta_m安定した分類子について、次の一般化誤差を導き出すことができます。

MS一般的に言えば、サンプル数に対する汎化誤差限界の収束率のほうが気になります。はハイパーパラメータとみなされm_G、最後の 2 つの項目は混合分布に関する汎化エラーとして記録され、「最も効果的な拡張の数」を定義できます。

この設定では、データ拡張なしの場合 ( m_G=0) と比較すると、生成モデル学習とは異なる、生成データ拡張が下流の分類タスクを容易にできる (できない) 場合を特徴付ける以下の十分条件を得ることができます。関連:

混合ガウスモデル

私たちの理論の正しさを検証するために、最初に単純な混合ガウス モデルの設定を検討しました。

定理 3.1 の各項目を混合ガウス モデルのシナリオで具体的に計算すると、次のように推定できます。

  1. データ量がMS十分である場合、生成的データ拡張では、「最も効果的な拡張量」を採用したとしても、下流タスクの分類パフォーマンスを向上させることは困難です。
  2. データ量MSが少ない場合、汎化誤差は次元などの項目が支配的ですが、このとき、生成データ拡張により汎化誤差を一定レベルで低減できます。とても必要です。

敵対的生成ネットワーク

深層学習の場合も考えます。生成モデルは MLP 敵対的生成ネットワーク、分類器は L 層 MLP または CNN であると仮定します。損失関数はバイナリ クロス エントロピー、最適化アルゴリズムは SGD です。損失関数は滑らかで、層 l のニューラル ネットワーク パラメーターはW_l‖ ‖ によって制御できると仮定します。次の結論を導き出すことができます。

  1. データ量がMS十分な場合、生成データ拡張によって下流タスクの分類パフォーマンスを向上させることも難しく、さらに悪化する可能性があります。
  2. データ量がMS少ない場合、汎化誤差は次元などの他の項目によって支配されますが、このとき生成データ拡張により汎化誤差を一定レベルで低減できます。これは、過学習シナリオでも同様に、生成されたデータの拡張が必要です。

実験

混合ガウスモデルのシミュレーション実験

ガウス分布の混合に関する理論を検証し、データ サイズMS、データ次元 d およびを調整しますm_G=\ガンマ m_S実験結果を以下の図に示します。

  1. MS図 (a) を観察すると、 d が十分に大きい場合、生成データ拡張の導入によって汎化誤差が大きく変化しないことがわかります。
  2. MS図 (d) を観察すると、固定された場合、実際の汎化誤差は確かに O(d) オーダーであり、\ガンマ拡張の数が増加するにつれて汎化誤差は一定のレベルで減少することがわかります。
  3. 他の 4 つの図では、限界が傾向上の汎化誤差をある程度予測できることを検証するために 2 つの状況を選択しました。

深い生成モデルの実験

分類器として ResNet、深層生成モデルとして cDCGAN、StyleGANv2-ADA、EDM を使用し、CIFAR-10 データセットで実験を実施しました。実験結果を以下に示します。トレーニング セットのトレーニング エラーは 0 に近いため、テスト セットのエラー率は汎化誤差のより適切な推定値となります。追加のデータ拡張 (反転など) を行うかどうかによって、MS妥当性を概算します。

方法 リポジトリ
CDCGA GitHub - znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN: MNIST データセット用の条件付き敵対的生成ネットワーク (cGAN) と条件付き深層畳み込み敵対的生成ネットワーク (cDCGAN) の Pytorch 実装
スタイルGAN GitHub - NVlabs/stylegan2-ada-pytorch: StyleGAN2-ADA - 公式 PyTorch 実装
EDM データとトレーニング GitHub - wzekai99/DM-改善-AT: 論文「より良い拡散モデルにより敵対的トレーニングがさらに改善される」(ICML 2023) のコード
bGMM GitHub - ML-GSAI/Revisiting-Dis-vs-Gen-Classifiers: 「Revisiting Discriminative vs. Generative Classifiers: Theory and Implications」の公式実装。

おすすめ

転載: blog.csdn.net/m0_61899108/article/details/133521671