Hacia la comprensión del aumento de datos generativos, NeurIPS 2023
Documento: https://arxiv.org/abs/2305.17476
Código: https://github.com/ML-GSAI/Understanding-GDA
Interpretación y reimpresión: [NeurIPS 2023] Hacia la comprensión del aumento de datos generativos - Zhihu (zhihu.com)
Descripción general
El aumento de datos generativos expande el conjunto de datos generando nuevas muestras a través de modelos generativos condicionales, mejorando así el rendimiento de la clasificación en diversas tareas de aprendizaje. Sin embargo, pocos estudios han investigado teóricamente los efectos del aumento de datos generativos. Para llenar este vacío, este artículo construye un límite de error de generalización general basado en la estabilidad en este entorno no IID. Con base en el límite de generalización general, el artículo explora más a fondo la situación de aprendizaje del modelo de mezcla gaussiana y la red generativa adversaria. En ambos casos, el artículo demuestra que, aunque el aumento de datos generativos no disfruta de tasas de aprendizaje más rápidas, puede mejorar las garantías de aprendizaje a un nivel constante cuando el conjunto de entrenamiento es pequeño, que es el caso cuando se produce un sobreajuste. Finalmente, los resultados de la simulación del modelo de mezcla gaussiana y los resultados experimentales de la red generativa adversaria respaldan las conclusiones teóricas del artículo.
conclusión principal
Definición de símbolo
Aumento de datos generativos
generalmente
Podemos derivar el siguiente error de generalización para cualquier generador y clasificador consistente y estable:
En términos generales, nos preocupa más la tasa de convergencia del límite del error de generalización con respecto al número de muestras. se considerarán hiperparámetros y los dos últimos elementos se registrarán como error de generalización con distribución mixta, y se podrá definir el "número más efectivo de mejoras":
En este escenario, y comparándolo con el caso sin aumento de datos ( ), podemos obtener las siguientes condiciones suficientes, que caracterizan cuándo el aumento de datos generativos puede (no) facilitar las tareas de clasificación posteriores, lo cual es diferente del aprendizaje de modelos generativos. relacionado con:
Modelo de mezcla gaussiana
Para verificar la exactitud de nuestra teoría, primero consideramos el establecimiento de un modelo de mezcla gaussiano simple.
Calculamos específicamente cada elemento del Teorema 3.1 en el escenario del modelo de mezcla gaussiana y podemos deducir
- Cuando la cantidad de datos es suficiente, es difícil que el aumento de datos generativos mejore el rendimiento de clasificación de las tareas posteriores, incluso si adoptamos la "cantidad de aumento más efectiva".
- Cuando la cantidad de datos es pequeña, el error de generalización está dominado por dimensiones y otros elementos. En este momento, la mejora de datos generativos puede reducir el error de generalización a un nivel constante, lo que significa que en escenarios de sobreajuste, la mejora de datos generativos es muy necesario.
Red de confrontación generativa
También consideramos el caso del aprendizaje profundo. Suponemos que el modelo generativo es una red generativa adversaria MLP y el clasificador es un MLP de capa L o CNN. La función de pérdida es entropía cruzada binaria y el algoritmo de optimización es SGD. Suponemos que la función de pérdida es fluida y que los parámetros de la red neuronal de la capa l pueden controlarse mediante ‖ ‖. Podemos sacar las siguientes conclusiones:
- Cuando la cantidad de datos es suficiente, el aumento de datos generativos también es difícil de mejorar el rendimiento de clasificación de las tareas posteriores, e incluso puede empeorarlo.
- Cuando la cantidad de datos es pequeña, el error de generalización está dominado por otros elementos como las dimensiones. En este momento, la mejora de los datos generativos puede reducir el error de generalización a un nivel constante. De manera similar, esto significa que en el escenario de sobreajuste, la La mejora de los datos generados es necesaria.
experimento
Experimento de simulación del modelo de mezcla gaussiana
Verificamos nuestra teoría en una mezcla de distribuciones gaussianas, donde ajustamos el tamaño de los datos , la dimensión de los datos d y . Los resultados experimentales se muestran en la siguiente figura:
- Al observar la Figura (a), podemos encontrar que cuando d es lo suficientemente grande, la introducción del aumento de datos generativos no cambia significativamente el error de generalización.
- Al observar la Figura (d), podemos encontrar que cuando se corrige, el error de generalización real es de orden O (d), y a medida que aumenta el número de mejoras, el error de generalización disminuye a un nivel constante.
- En las otras cuatro imágenes, seleccionamos dos situaciones para verificar que nuestro límite puede predecir el error de generalización hasta cierto punto en la tendencia.
Experimento de modelo generativo profundo
Utilizamos ResNet como clasificador, cDCGAN, StyleGANv2-ADA y EDM como modelos generativos profundos y realizamos experimentos en el conjunto de datos CIFAR-10. Los resultados experimentales se muestran a continuación. Dado que los errores de entrenamiento en el conjunto de entrenamiento son cercanos a 0, la tasa de error en el conjunto de prueba es una mejor estimación del error de generalización. Calculamos la idoneidad en función de si realizamos un aumento de datos adicional (volteando, etc.) .