Rumo à compreensão do aumento de dados generativos, NeurIPS 2023
Artigo: https://arxiv.org/abs/2305.17476
Código: https://github.com/ML-GSAI/Understanding-GDA
Interpretação e reimpressão: [NeurIPS 2023] Rumo à compreensão do aumento de dados generativos - Zhihu (zhihu.com)
Visão geral
O aumento generativo de dados expande o conjunto de dados gerando novas amostras por meio de modelos generativos condicionais, melhorando assim o desempenho da classificação em várias tarefas de aprendizagem. No entanto, poucos estudos investigaram teoricamente os efeitos do aumento generativo de dados. Para preencher esta lacuna, este artigo constrói um erro de generalização geral baseado na estabilidade neste ambiente não-IID. Com base no limite de generalização geral, o artigo explora ainda mais a situação de aprendizagem do modelo de mistura gaussiana e da rede adversária generativa. Em ambos os casos, o artigo demonstra que embora o aumento generativo de dados não desfrute de taxas de aprendizagem mais rápidas, pode melhorar as garantias de aprendizagem a um nível constante quando o conjunto de treino é pequeno, o que é o caso quando ocorre overfitting. O tempo é muito importante. Finalmente, os resultados da simulação do modelo de mistura gaussiana e os resultados experimentais da rede adversária generativa apoiam as conclusões teóricas do artigo.
conclusão principal
Definição de símbolo
Aumento generativo de dados
geralmente
Podemos derivar o seguinte erro de generalização para qualquer gerador e classificador consistente e estável:
De modo geral, estamos mais preocupados com a taxa de convergência do limite do erro de generalização em relação ao número de amostras. serão considerados hiperparâmetros, e os dois últimos itens serão registrados como erro de generalização em relação à distribuição mista, e o "número mais eficaz de melhorias" pode ser definido:
Neste cenário, e comparando-o com o caso sem aumento de dados ( ), podemos obter as seguintes condições suficientes, que caracterizam quando o aumento generativo de dados pode (não) facilitar tarefas de classificação downstream, o que é diferente da aprendizagem de modelo generativo. relacionado a:
Modelo de mistura gaussiana
Para verificar a exatidão de nossa teoria, consideramos primeiro a configuração de um modelo de mistura gaussiana simples.
Calculamos especificamente cada item do Teorema 3.1 no cenário do modelo de mistura gaussiana, e podemos deduzir
- Quando a quantidade de dados é suficiente, o aumento generativo de dados é difícil de melhorar o desempenho de classificação das tarefas downstream, mesmo se adotarmos a "quantidade de aumento mais eficaz".
- Quando a quantidade de dados é pequena, o erro de generalização é dominado por dimensões e outros itens.Neste momento, o aprimoramento generativo de dados pode reduzir o erro de generalização em um nível constante, o que significa que em cenários de sobreajuste, o aprimoramento generativo de dados é Muito necessário.
Rede Adversarial Gerativa
Também consideramos o caso do aprendizado profundo. Assumimos que o modelo generativo é uma rede adversária generativa MLP e o classificador é um MLP ou CNN de camada L. A função de perda é a entropia cruzada binária e o algoritmo de otimização é SGD. Assumimos que a função de perda é suave e os parâmetros da rede neural da camada l podem ser controlados por ‖ ‖. Podemos tirar as seguintes conclusões:
- Quando a quantidade de dados é suficiente, o aumento generativo de dados também é difícil para melhorar o desempenho da classificação das tarefas downstream e pode até piorar.
- Quando a quantidade de dados é pequena, o erro de generalização é dominado por outros itens, como dimensões. Neste momento, o aprimoramento generativo de dados pode reduzir o erro de generalização em um nível constante. Da mesma forma, isso significa que no cenário de sobreajuste, o O aprimoramento dos dados gerados é necessário.
experimentar
Experimento de simulação de modelo de mistura gaussiana
Verificamos nossa teoria em uma mistura de distribuições gaussianas, onde ajustamos o tamanho dos dados , a dimensão dos dados d e . Os resultados experimentais são mostrados na figura abaixo:
- Observando a Figura (a), podemos descobrir que quando d é grande o suficiente, a introdução do aumento generativo de dados não altera significativamente o erro de generalização.
- Observando a Figura (d), podemos descobrir que quando fixo, o erro de generalização real é de fato da ordem O(d), e à medida que o número de melhorias aumenta, o erro de generalização diminui a um nível constante.
- Nas outras quatro imagens, selecionamos duas situações para verificar se nosso limite pode prever até certo ponto o erro de generalização da tendência.
Experimento de modelo generativo profundo
Usamos ResNet como classificador, cDCGAN, StyleGANv2-ADA e EDM como modelos generativos profundos e conduzimos experimentos no conjunto de dados CIFAR-10. Os resultados experimentais são mostrados abaixo. Como os erros de treinamento no conjunto de treinamento são próximos de 0, a taxa de erro no conjunto de teste é uma estimativa melhor do erro de generalização. Aproximamos a adequação dependendo se fazemos aumento adicional de dados (inversão, etc.) .