【NeurIPS 2023】 Hacia la comprensión del aumento de datos generativos

Hacia la comprensión del aumento de datos generativos, NeurIPS 2023

Documento: https://arxiv.org/abs/2305.17476

Código: https://github.com/ML-GSAI/Understanding-GDA

Interpretación y reimpresión: [NeurIPS 2023] Hacia la comprensión del aumento de datos generativos - Zhihu (zhihu.com)

Descripción general

El aumento de datos generativos expande el conjunto de datos generando nuevas muestras a través de modelos generativos condicionales, mejorando así el rendimiento de la clasificación en diversas tareas de aprendizaje. Sin embargo, pocos estudios han investigado teóricamente los efectos del aumento de datos generativos. Para llenar este vacío, este artículo construye un límite de error de generalización general basado en la estabilidad en este entorno no IID. Con base en el límite de generalización general, el artículo explora más a fondo la situación de aprendizaje del modelo de mezcla gaussiana y la red generativa adversaria. En ambos casos, el artículo demuestra que, aunque el aumento de datos generativos no disfruta de tasas de aprendizaje más rápidas, puede mejorar las garantías de aprendizaje a un nivel constante cuando el conjunto de entrenamiento es pequeño, que es el caso cuando se produce un sobreajuste. Finalmente, los resultados de la simulación del modelo de mezcla gaussiana y los resultados experimentales de la red generativa adversaria respaldan las conclusiones teóricas del artículo.

conclusión principal

Definición de símbolo 

Aumento de datos generativos

generalmente

Podemos derivar el siguiente error de generalización para cualquier generador y \beta_mclasificador consistente y estable:

En términos generales, nos preocupa más la EMtasa de convergencia del límite del error de generalización con respecto al número de muestras. se m_Gconsiderarán hiperparámetros y los dos últimos elementos se registrarán como error de generalización con distribución mixta, y se podrá definir el "número más efectivo de mejoras":

En este escenario, y comparándolo con el caso sin aumento de datos ( m_G=0), podemos obtener las siguientes condiciones suficientes, que caracterizan cuándo el aumento de datos generativos puede (no) facilitar las tareas de clasificación posteriores, lo cual es diferente del aprendizaje de modelos generativos. relacionado con:

Modelo de mezcla gaussiana

Para verificar la exactitud de nuestra teoría, primero consideramos el establecimiento de un modelo de mezcla gaussiano simple.

Calculamos específicamente cada elemento del Teorema 3.1 en el escenario del modelo de mezcla gaussiana y podemos deducir

  1. Cuando la cantidad de datos EMes suficiente, es difícil que el aumento de datos generativos mejore el rendimiento de clasificación de las tareas posteriores, incluso si adoptamos la "cantidad de aumento más efectiva".
  2. Cuando la cantidad de datos EMes pequeña, el error de generalización está dominado por dimensiones y otros elementos. En este momento, la mejora de datos generativos puede reducir el error de generalización a un nivel constante, lo que significa que en escenarios de sobreajuste, la mejora de datos generativos es muy necesario.

Red de confrontación generativa

También consideramos el caso del aprendizaje profundo. Suponemos que el modelo generativo es una red generativa adversaria MLP y el clasificador es un MLP de capa L o CNN. La función de pérdida es entropía cruzada binaria y el algoritmo de optimización es SGD. Suponemos que la función de pérdida es fluida y que los parámetros de la red neuronal de la capa l pueden W_lcontrolarse mediante ‖ ‖. Podemos sacar las siguientes conclusiones:

  1. Cuando la cantidad de datos EMes suficiente, el aumento de datos generativos también es difícil de mejorar el rendimiento de clasificación de las tareas posteriores, e incluso puede empeorarlo.
  2. Cuando la cantidad de datos EMes pequeña, el error de generalización está dominado por otros elementos como las dimensiones. En este momento, la mejora de los datos generativos puede reducir el error de generalización a un nivel constante. De manera similar, esto significa que en el escenario de sobreajuste, la La mejora de los datos generados es necesaria.

experimento

Experimento de simulación del modelo de mezcla gaussiana

Verificamos nuestra teoría en una mezcla de distribuciones gaussianas, donde ajustamos el tamaño de los datos EM, la dimensión de los datos d y m_G=\gamma m_S. Los resultados experimentales se muestran en la siguiente figura:

  1. Al observar la Figura (a), podemos encontrar que cuando EMd es lo suficientemente grande, la introducción del aumento de datos generativos no cambia significativamente el error de generalización.
  2. Al observar la Figura (d), podemos encontrar que cuando EMse corrige, el error de generalización real es de orden O (d), y a medida que \gamaaumenta el número de mejoras, el error de generalización disminuye a un nivel constante.
  3. En las otras cuatro imágenes, seleccionamos dos situaciones para verificar que nuestro límite puede predecir el error de generalización hasta cierto punto en la tendencia.

Experimento de modelo generativo profundo

Utilizamos ResNet como clasificador, cDCGAN, StyleGANv2-ADA y EDM como modelos generativos profundos y realizamos experimentos en el conjunto de datos CIFAR-10. Los resultados experimentales se muestran a continuación. Dado que los errores de entrenamiento en el conjunto de entrenamiento son cercanos a 0, la tasa de error en el conjunto de prueba es una mejor estimación del error de generalización. EMCalculamos la idoneidad en función de si realizamos un aumento de datos adicional (volteando, etc.) .

Método Repositorio
CDCGAN GitHub - znxlwm/pytorch-MNIST-CelebA-cGAN-cDCGAN: Implementación en Pytorch de Redes Adversarias Generativas condicionales (cGAN) y Redes Adversariales Generativas Convolucionales Profundas (cDCGAN) condicionales para el conjunto de datos MNIST
EstiloGAN GitHub - NVlabs/stylegan2-ada-pytorch: StyleGAN2-ADA - Implementación oficial de PyTorch
Datos y entrenamiento de electroerosión GitHub - wzekai99/DM-Improves-AT: Código para el artículo "Mejores modelos de difusión mejoran aún más el entrenamiento de adversarios" (ICML 2023)
bGMM GitHub - ML-GSAI/Revisiting-Dis-vs-Gen-Classifiers: Implementación oficial de "Revisitando clasificadores discriminativos versus generativos: teoría e implicaciones".

Supongo que te gusta

Origin blog.csdn.net/m0_61899108/article/details/133521671
Recomendado
Clasificación