【NeurIPS 2023】 Rumo à compreensão do aumento de dados generativos

Rumo à compreensão do aumento de dados generativos, NeurIPS 2023

Artigo: https://arxiv.org/abs/2305.17476

Código: https://github.com/ML-GSAI/Understanding-GDA

Interpretação e reimpressão: [NeurIPS 2023] Rumo à compreensão do aumento de dados generativos - Zhihu (zhihu.com)

Visão geral

O aumento generativo de dados expande o conjunto de dados gerando novas amostras por meio de modelos generativos condicionais, melhorando assim o desempenho da classificação em várias tarefas de aprendizagem. No entanto, poucos estudos investigaram teoricamente os efeitos do aumento generativo de dados. Para preencher esta lacuna, este artigo constrói um erro de generalização geral baseado na estabilidade neste ambiente não-IID. Com base no limite de generalização geral, o artigo explora ainda mais a situação de aprendizagem do modelo de mistura gaussiana e da rede adversária generativa. Em ambos os casos, o artigo demonstra que embora o aumento generativo de dados não desfrute de taxas de aprendizagem mais rápidas, pode melhorar as garantias de aprendizagem a um nível constante quando o conjunto de treino é pequeno, o que é o caso quando ocorre overfitting. O tempo é muito importante. Finalmente, os resultados da simulação do modelo de mistura gaussiana e os resultados experimentais da rede adversária generativa apoiam as conclusões teóricas do artigo.

conclusão principal

Definição de símbolo 

Aumento generativo de dados

geralmente

Podemos derivar o seguinte erro de generalização para qualquer gerador e \beta_mclassificador consistente e estável:

De modo geral, estamos mais preocupados com a EMtaxa de convergência do limite do erro de generalização em relação ao número de amostras. serão m_Gconsiderados hiperparâmetros, e os dois últimos itens serão registrados como erro de generalização em relação à distribuição mista, e o "número mais eficaz de melhorias" pode ser definido:

Neste cenário, e comparando-o com o caso sem aumento de dados ( m_G=0), podemos obter as seguintes condições suficientes, que caracterizam quando o aumento generativo de dados pode (não) facilitar tarefas de classificação downstream, o que é diferente da aprendizagem de modelo generativo. relacionado a:

Modelo de mistura gaussiana

Para verificar a exatidão de nossa teoria, consideramos primeiro a configuração de um modelo de mistura gaussiana simples.

Calculamos especificamente cada item do Teorema 3.1 no cenário do modelo de mistura gaussiana, e podemos deduzir

  1. Quando a quantidade de dados EMé suficiente, o aumento generativo de dados é difícil de melhorar o desempenho de classificação das tarefas downstream, mesmo se adotarmos a "quantidade de aumento mais eficaz".
  2. Quando a quantidade de dados EMé pequena, o erro de generalização é dominado por dimensões e outros itens.Neste momento, o aprimoramento generativo de dados pode reduzir o erro de generalização em um nível constante, o que significa que em cenários de sobreajuste, o aprimoramento generativo de dados é Muito necessário.

Rede Adversarial Gerativa

Também consideramos o caso do aprendizado profundo. Assumimos que o modelo generativo é uma rede adversária generativa MLP e o classificador é um MLP ou CNN de camada L. A função de perda é a entropia cruzada binária e o algoritmo de otimização é SGD. Assumimos que a função de perda é suave e os parâmetros da rede neural da camada l podem ser W_lcontrolados por ‖ ‖. Podemos tirar as seguintes conclusões:

  1. Quando a quantidade de dados EMé suficiente, o aumento generativo de dados também é difícil para melhorar o desempenho da classificação das tarefas downstream e pode até piorar.
  2. Quando a quantidade de dados EMé pequena, o erro de generalização é dominado por outros itens, como dimensões. Neste momento, o aprimoramento generativo de dados pode reduzir o erro de generalização em um nível constante. Da mesma forma, isso significa que no cenário de sobreajuste, o O aprimoramento dos dados gerados é necessário.

experimentar

Experimento de simulação de modelo de mistura gaussiana

Verificamos nossa teoria em uma mistura de distribuições gaussianas, onde ajustamos o tamanho dos dados EM, a dimensão dos dados d e m_G=\gama m_S. Os resultados experimentais são mostrados na figura abaixo:

  1. Observando a Figura (a), podemos descobrir que quando EMd é grande o suficiente, a introdução do aumento generativo de dados não altera significativamente o erro de generalização.
  2. Observando a Figura (d), podemos descobrir que quando EMfixo, o erro de generalização real é de fato da ordem O(d), e à medida que o número \gamade melhorias aumenta, o erro de generalização diminui a um nível constante.
  3. Nas outras quatro imagens, selecionamos duas situações para verificar se nosso limite pode prever até certo ponto o erro de generalização da tendência.

Experimento de modelo generativo profundo

Usamos ResNet como classificador, cDCGAN, StyleGANv2-ADA e EDM como modelos generativos profundos e conduzimos experimentos no conjunto de dados CIFAR-10. Os resultados experimentais são mostrados abaixo. Como os erros de treinamento no conjunto de treinamento são próximos de 0, a taxa de erro no conjunto de teste é uma estimativa melhor do erro de generalização. EMAproximamos a adequação dependendo se fazemos aumento adicional de dados (inversão, etc.) .

Método Repositório
CDCGAN GitHub - znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN: Implementação Pytorch de Redes Adversariais Generativas condicionais (cGAN) e Redes Adversariais Gerativas Convolucionais Profundas condicionais (cDCGAN) para conjunto de dados MNIST
EstiloGAN GitHub - NVlabs/stylegan2-ada-pytorch: StyleGAN2-ADA - Implementação oficial do PyTorch
Dados e treinamento de EDM GitHub - wzekai99/DM-Improves-AT: Código para o artigo "Melhores modelos de difusão melhoram ainda mais o treinamento adversário" (ICML 2023)
bGMM GitHub - ML-GSAI/Revisiting-Dis-vs-Gen-Classifiers: Implementação oficial para "Revisiting Discriminative vs. Generative Classifiers: Theory and Implications".

Acho que você gosta

Origin blog.csdn.net/m0_61899108/article/details/133521671
Recomendado
Clasificación