Python implementa el entrenamiento y reconocimiento de la red neuronal convolucional LeNet-5 y AlexNet

Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348
Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348

Red neuronal convolucional CNN

Contenido y requisitos del experimento.

  • Escriba un programa para implementar la red neuronal convolucional LeNet-5, entrenar y reconocer la base de datos de dígitos escritos a mano MNIST y mostrar la precisión, etc.
  • Elija su propia red neuronal y realice entrenamiento y reconocimiento de objetos de imagen en la base de datos CIFAR-10.

equipo de experimento

Pitón 3.7

Plataforma de desarrollo: Windows10 Visual Studio Code

Biblioteca de aprendizaje automático: torch 1.6.0 torchvision 0.7.0

Auxiliar: CUDA 10.2 para aceleración de GPU

Implementación

3.1 Implementación de LeNet-5

Usando la derivación de la clase de antorcha nn.Module, la estructura de LeNet5 se puede escribir de la siguiente manera: se llama a la función nn.Conv2d() para establecer la capa de convolución, y la función nn.Linear() se usa para realizar la operación de conexión completa. En el proceso de conducción directa, se especifican dos agrupaciones mediante la función F.max_pool2d. Después de cada capa, se llama a la función F.relu() en el resultado para activarlo y formar una nueva salida.

La transferencia de la imagen del enlace externo falló. El sitio de origen puede tener un mecanismo anti-leeching. Se recomienda guardar la imagen y cargarla directamente.

En el proceso de implementación de la red neuronal convolucional, se encuentra una dificultad para llamar al módulo de carga de datos de pytorch. Llame a torch.utils.data.DataLoader (), establezca el tamaño del lote, si se reorganizará aleatoriamente y num_workers (número de procesos). Dado que se utiliza Windows, la compatibilidad con subprocesos múltiples no es buena.

Proceso de capacitación: use el optimizador de funciones de optimización (usando el algoritmo de Adam) y la función de pérdida (función de entropía cruzada CrossEntropyLoss), y llame a la función backard () en la pérdida para realizar el proceso de retropropagación. Preste atención a la configuración train() de la red antes del entrenamiento y habilite la normalización y abandono por lotes para evitar que la red se sobreadapte.

Proceso de prueba: habilite el modo eval(), propague los datos de entrada a la red y tome el valor máximo de la salida como resultado de la predicción pred.

3.2 Implementación de AlexNet

La definición de red es la siguiente:

Tenga en cuenta que los datos deben procesarse previamente antes del entrenamiento y la función de procesamiento de torchvision se utiliza para cambiar su tamaño y convertirlos en un tensor. Además, se llama a la función Normalizar para transformar el tensor original del rango (0,1) al rango (-1,1).

El entrenamiento y detección de CIFAR-10 es similar al de MNIST y no se describirá nuevamente.

Resultados experimentales y análisis.

4.1 Capacitación LeNet-5 y reconocimiento de MNIST

Establezca BATCH_SIZE en 512 y entrene durante un total de 10 épocas. Cada época pasa los datos de entrenamiento y luego los datos de prueba para obtener los valores de precisión y función de pérdida. Los resultados del entrenamiento y las pruebas se guardan en LeNet.log y el modelo se guarda como LeNet.pth.

Los resultados del entrenamiento se visualizan de la siguiente manera:

4.2 Capacitación y reconocimiento de CIFAR-10 por AlexNet

Establezca BATCH_SIZE en 32 y entrene durante un total de 20 épocas. Cada época pasa los datos de entrenamiento y luego los datos de prueba para obtener los valores de precisión y función de pérdida. Los resultados del entrenamiento y las pruebas se guardan en AlexNet.log y el modelo se guarda como AlexNet.pth.

Debido a que la red AlexNet es relativamente compleja y el volumen de datos CIFAR-10 también es grande, la estructura de la red entrenada ahora se imprime de la siguiente manera para verificar si es correcta:

Primero seleccionamos aleatoriamente un lote de datos para probar los resultados del entrenamiento:

Comparando etiquetas reales y etiquetas previstas: 27 de 32 imágenes fueron juzgadas correctamente, con una tasa de precisión de aproximadamente el 84%.

GroundTruth:  cat  ship  ship airplane  frog  frog  automobile  frog   cat   automobile  airplane truck   dog horse truck  ship   dog horse  ship  frog horse  airplane  deer  truck
dog   bird  deer airplane truck  frog  frog   dog 
Predicted:    cat  ship  ship airplane  frog  frog  truck     frog   cat   automobile airplane  truck   dog horse truck  ship   dog horse  ship  frog horse  bird      airplane truck  deer  frog  deer airplane truck  frog  frog   dog 

Además, los resultados de las pruebas con 50.000 datos de entrenamiento mostraron una tasa de precisión del 92% y el resultado de 10.000 datos de pruebas nuevos fue del 77%. Entre las diez etiquetas, ship tiene la tasa de precisión más alta, 91%, y cat tiene la tasa de precisión más baja, casi 60%.

Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348
Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348

Supongo que te gusta

Origin blog.csdn.net/newlw/article/details/132625381
Recomendado
Clasificación