Entrene de manera eficiente modelos de aprendizaje profundo en el conjunto de datos CIFAR

Entrene de manera eficiente modelos de aprendizaje profundo en el conjunto de datos CIFAR

  Actualmente, hay muchos modelos preentrenados en el conjunto de datos de ImageNet, pero no muchos modelos preentrenados en el conjunto de datos CIFAR. CIFAR se usa a menudo como un conjunto de datos para verificar ideas innovadoras. Aunque es un conjunto de datos pequeño, diferentes configuraciones de entrenamiento aún conducirán a una precisión del modelo significativamente diferente. Entrenar un modelo que logre la precisión de referencia es un trabajo necesario y que requiere mucho tiempo.

1. Introducción

  Anteriormente, cuando necesitaba personalmente el modelo de preentrenamiento CIFAR, siempre descartaba la configuración de entrenamiento y la estructura del modelo después de usarlo, lo que resultó en rediseñar la estructura y establecer parámetros de entrenamiento la próxima vez que necesitaba entrenar una nueva estructura de modelo, desperdiciando un mucho tiempo. Por supuesto, la configuración de capacitación de CIFAR se puede obtener fácilmente de Github, Gitee y otras plataformas, pero muchos códigos no tienen en cuenta la escalabilidad y, básicamente, no hay límite para la potencia informática. Para resolver el problema del entrenamiento del modelo de conjunto de datos CIFAR de una vez por todas, este documento se basa en los códigos de entrenamiento del modelo de muchos documentos y escribe un marco de entrenamiento del modelo de conjunto de datos CIFAR basado en pytorch (aquí simplemente llamamos marco CMTF ) . En el proceso de escribir el código, descubrí que CMTF puede ser más amigable con los dispositivos GPU de bajo rendimiento.

  CMTF adopta una configuración de entrenamiento simple y eficiente, tiene un estilo de registro claro y es compatible con el entrenamiento de arquitectura VGG (VGG11, 13, 16, 19 y su versión con Batchnormal) y ResNET (ResNet20, 32, 44, 56, 110) . En la actualidad, el punto de control se ha obtenido en los modelos VGG16BN y ResNet20 entrenados en CIFAR10, y la precisión de referencia no es inferior a la de algunos trabajos académicos . Además, CMTF no necesita configurar la computación paralela de múltiples tarjetas y puede agregar nuevas estructuras de modelo y pequeños conjuntos de datos con operaciones simples, lo cual es fácil de expandir.

2. Instalación y operación de CMTF

  CMTF se cargó en github y el proyecto se puede implementar en la plataforma informática local de acuerdo con los siguientes pasos.

  • Primero, necesita descargar el proyecto al local.
    • Para usuarios de Windows, ingrese directamente al sitio web para descargar el proyecto y use el IDE local para abrir el proyecto. Si el usuario no puede ingresar a github, también puede usar Baidu Cloud para descargar :

      Enlace: https://pan.baidu.com/s/1ab1Z1yvhlqWU8pJwGeKcNQ?pwd=07un
      Código de extracción: 07un

    • Para usuarios de Linux, puede usar el método git para descargar el proyecto al servidor de Linux:

      git clone SunYaFeng1996/CifarModelTrainingFramework.git

Cuando el servidor no está conectado a la red externa, también puede ingresar directamente al sitio web para descargar el proyecto y cargarlo en el servidor Linux.

  • Luego configure el entorno de python.

Si los usuarios quieren usar exactamente el mismo entorno, pueden usar conda para crear un nuevo entorno:

cd CMTF所在目录
conda env create -f environment.yaml

Sin embargo, los requisitos del entorno Python de CMTF son muy simples, solo se requieren pytorch y tensorboardX , y se recomienda a los usuarios que instalen estos dos paquetes directamente en el entorno existente.

  • Finalmente, ejecute el programa para entrenar el modelo.
    La ubicación del archivo central del modelo de entrenamiento es ModelTrainingFramework /training /load_model.py, los usuarios primero deben ingresar al directorio del proyecto:
    cd 本地路径/ModelTrainingFramework
    
    Tenga en cuenta que los usuarios solo pueden ingresar a la carpeta ModelTrainingFramework para ejecutar el comando anterior , y si ingresan a la carpeta de capacitación para ejecutar, se informará un error porque la importación no se realizó correctamente. A continuación, el usuario solo necesita configurar algunos parámetros para comenzar a entrenar el modelo:
    python ./training/train_model.py \
      --arch=resnet20 \
      --dataset=CIFAR10 \
      --save_path=./save/resnet20/
    
    El parámetro save_path debe tener asignado un valor estándar , de lo contrario, el nombre de la carpeta donde se encuentran los archivos de registro y el punto de control no se corresponderá con el archivo arch.

  load_model.py es el programa principal, y los parámetros y rangos opcionales se muestran en la siguiente tabla.

parámetro significado   valores predeterminados Opciones/Observaciones
conjunto de datos conjunto de datos CIFAR10 CIFAR10, CIFAR100
Ruta de datos ruta del conjunto de datos ./conjuntos de datos/CIFAR10 Tanto las rutas absolutas como las relativas son aceptables.
arco estructura del modelo resnet20 resnet20, resnet34, resnet56, resnet110, vgg11(13,16,19), vgg11(13,16,19)bn
guardar_ruta Ruta de almacenamiento del punto de control ./guardar/resnet20 Tanto las rutas absolutas como las relativas son aceptables.
Semilla manual semilla de número aleatorio Ninguno Los números enteros son aceptables y no se recomienda modificarlos.
dispositivo Plataforma informática para modelos de entrenamiento Cuda cuda, CPU
imprimir_freq cuantos lotes producir 100 Los números enteros son aceptables, no excedan el número de lotes y no se recomienda la modificación
prueba_bs Tamaño del lote del conjunto de prueba 256 Los números enteros son aceptables y no se recomienda modificarlos.
tren_bs tamaño del lote del conjunto de entrenamiento 64 Los números enteros son aceptables y no se recomienda modificarlos.
trabajadores_de_prueba Número de procesos del conjunto de pruebas 0 Los números enteros son aceptables y no se recomienda modificarlos.
tren_trabajadores Número de procesos de conjuntos de entrenamiento 0 Los números enteros son aceptables y no se recomienda modificarlos.
épocas período de iteración de entrenamiento 160 No se recomienda ninguna modificación.
época_de_inicio desde que periodo 0 No se recomienda ninguna modificación.
yo La tasa de aprendizaje al inicio del entrenamiento. 0.1 No se recomienda ninguna modificación.
impulso impulso 0.9 No se recomienda ninguna modificación.
peso_decay decaimiento de peso 1.00E-04 No se recomienda ninguna modificación.

  Para simplificar la operación, se recomienda que los usuarios solo agreguen los parámetros de arch, dataset y save_path. Una vez completada la operación anterior, el punto de control, el registro de operación y la curva de entrenamiento se guardarán en la carpeta save/resnet20.

3. Análisis de ejecución de CMTF

  Tomando los parámetros operativos de la Sección 2 como ejemplo, la salida de CMTF al principio se muestra en la siguiente figura. CMTF generará primero la estructura del modelo establecida por el usuario, el conjunto de datos y la ruta de almacenamiento de los resultados de salida. Luego, realice una verificación del modelo y, finalmente, ingrese el entrenamiento iterativo de épocas.

captura de pantalla de la salida de ejecución
  Una vez completada la capacitación, la carpeta tb_log y resnet20_best.pth.tar, resnet20_checkpoint.pth.tar, curve.png, log_training_resnet20_xxxx.txt se generarán en la carpeta save/resnet20. tb_log es la carpeta disponible de tensorboardX, resnet20_best.pth.tar, resnet20_checkpoint.pth.tar son el punto de control almacenado después del entrenamiento y el punto de control óptimo durante el entrenamiento respectivamente , log_training_resnet20_xxxx.txt registra la salida del programa, curve.png es el gráfico de la curva de entrenamiento, como se muestra en la siguiente figura.
curva de entrenamiento
  La curva de entrenamiento muestra que el rendimiento del entrenamiento mejora mucho alrededor de la época 80 y la época 120, lo que se debe a la reducción del parámetro de tasa de aprendizaje lr en estas dos épocas:

if epoch in [args.epochs*0.5, args.epochs*0.75]:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= 0.1

  Esta es también la única estrategia de programación lr de CMTF, que es muy simple y eficiente. Después del entrenamiento, los usuarios pueden usar training/load_model.py para cargar el punto de control y verificar :

python ./training/load_model.py \
    --arch=resnet20 \
    --save_path=./save/resnet20 \
    --resume=./save/resnet20/resnet20_checkpoint.pth.tar

  Del mismo modo, los usuarios solo pueden ingresar a la carpeta ModelTrainingFramework para ejecutar los comandos anteriores, no a la carpeta de capacitación . Una vez completada la verificación, la salida será como se muestra en la siguiente figura.

inserte la descripción de la imagen aquí

  Primero, mostrará si el punto de control se cargó con éxito, luego generará la arquitectura, el conjunto de datos y la ruta de almacenamiento del punto de control, y finalmente generará la precisión de la prueba, la época y la tasa de aprendizaje. **Si el trabajo del usuario es operar el modelo después del entrenamiento, se recomienda continuar escribiendo código directamente en load_model.py. **Al ejecutarse y verificarse en la GPU Tesla P40, la precisión top1 de resnet20, vgg16 y vgg16bn es del 92,40 % y el 93,63 % respectivamente, alcanzando la precisión de referencia de muchos trabajos académicos, y lleva aproximadamente una hora y media.

4. CMTF extendido

  La estructura del modelo de CMTF, las operaciones del modelo y los códigos del conjunto de datos se encuentran en la carpeta all_utils. Los usuarios pueden agregar nuevos conjuntos de datos y estructuras de modelos con operaciones simples.

  • conjunto de datos Agregue la función GetDataLoader() en datasets_utils.py, agregue elif dataset == 'recién agregado conjunto de datos' después del juicio if, y luego agregue el código de procesamiento del conjunto de datos.
  • estructura del modelo. Agregue el archivo de estructura del modelo en la carpeta de modelos y defina el nuevo método de creación de modelos en la función GetModel() en models_utils.py.

  Al igual que los códigos en models_utils.py y models_utils.py, los usuarios pueden agregar rápidamente sus propios modelos y conjuntos de datos en CMTF.

Agradecimientos y Clausura

  El código de registro y el estilo de CMTF provienen de Neural Network Weight Attack , y la configuración de entrenamiento proviene del replanteamiento de la poda de red :

Liu, Z., Sun, M., Zhou, T., Huang, G. y Darrell, T. (2018). Repensar el valor de la poda de redes. preimpresión de arXiv arXiv:1810.05270.

Rakin, AS, He, Z. y Fan, D. (2019). Bit-flip attack: Red neuronal aplastante con búsqueda progresiva de bits. En Actas de la Conferencia internacional IEEE/CVF sobre visión artificial (págs. 1211-1220).

  Gracias por su investigación, se recomienda encarecidamente a los usuarios que citen los dos excelentes trabajos anteriores.

  CMTF es uno de mis pocos blogs y mi primer proyecto de github. Actualmente, espero mejorar mi nivel de escritura en el blog. Si los lectores encuentran dificultades en el proceso de implementación o se sienten confundidos al leer este artículo, dejen un mensaje o agreguen mi QQ: 1106295085. Responderé el domingo por la tarde y revisaré activamente este artículo.

Supongo que te gusta

Origin blog.csdn.net/qq_39068200/article/details/130162398
Recomendado
Clasificación