Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348
Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348
Red neuronal convolucional CNN
Contenido y requisitos del experimento.
- Escriba un programa para implementar la red neuronal convolucional LeNet-5, entrenar y reconocer la base de datos de dígitos escritos a mano MNIST y mostrar la precisión, etc.
- Elija su propia red neuronal y realice entrenamiento y reconocimiento de objetos de imagen en la base de datos CIFAR-10.
equipo de experimento
Pitón 3.7
Plataforma de desarrollo: Windows10 Visual Studio Code
Biblioteca de aprendizaje automático: torch 1.6.0 torchvision 0.7.0
Auxiliar: CUDA 10.2 para aceleración de GPU
Implementación
3.1 Implementación de LeNet-5
Usando la derivación de la clase de antorcha nn.Module, la estructura de LeNet5 se puede escribir de la siguiente manera: se llama a la función nn.Conv2d() para establecer la capa de convolución, y la función nn.Linear() se usa para realizar la operación de conexión completa. En el proceso de conducción directa, se especifican dos agrupaciones mediante la función F.max_pool2d. Después de cada capa, se llama a la función F.relu() en el resultado para activarlo y formar una nueva salida.
En el proceso de implementación de la red neuronal convolucional, se encuentra una dificultad para llamar al módulo de carga de datos de pytorch. Llame a torch.utils.data.DataLoader (), establezca el tamaño del lote, si se reorganizará aleatoriamente y num_workers (número de procesos). Dado que se utiliza Windows, la compatibilidad con subprocesos múltiples no es buena.
Proceso de capacitación: use el optimizador de funciones de optimización (usando el algoritmo de Adam) y la función de pérdida (función de entropía cruzada CrossEntropyLoss), y llame a la función backard () en la pérdida para realizar el proceso de retropropagación. Preste atención a la configuración train() de la red antes del entrenamiento y habilite la normalización y abandono por lotes para evitar que la red se sobreadapte.
Proceso de prueba: habilite el modo eval(), propague los datos de entrada a la red y tome el valor máximo de la salida como resultado de la predicción pred.
3.2 Implementación de AlexNet
La definición de red es la siguiente:
Tenga en cuenta que los datos deben procesarse previamente antes del entrenamiento y la función de procesamiento de torchvision se utiliza para cambiar su tamaño y convertirlos en un tensor. Además, se llama a la función Normalizar para transformar el tensor original del rango (0,1) al rango (-1,1).
El entrenamiento y detección de CIFAR-10 es similar al de MNIST y no se describirá nuevamente.
Resultados experimentales y análisis.
4.1 Capacitación LeNet-5 y reconocimiento de MNIST
Establezca BATCH_SIZE en 512 y entrene durante un total de 10 épocas. Cada época pasa los datos de entrenamiento y luego los datos de prueba para obtener los valores de precisión y función de pérdida. Los resultados del entrenamiento y las pruebas se guardan en LeNet.log y el modelo se guarda como LeNet.pth.
Los resultados del entrenamiento se visualizan de la siguiente manera:
4.2 Capacitación y reconocimiento de CIFAR-10 por AlexNet
Establezca BATCH_SIZE en 32 y entrene durante un total de 20 épocas. Cada época pasa los datos de entrenamiento y luego los datos de prueba para obtener los valores de precisión y función de pérdida. Los resultados del entrenamiento y las pruebas se guardan en AlexNet.log y el modelo se guarda como AlexNet.pth.
Debido a que la red AlexNet es relativamente compleja y el volumen de datos CIFAR-10 también es grande, la estructura de la red entrenada ahora se imprime de la siguiente manera para verificar si es correcta:
Primero seleccionamos aleatoriamente un lote de datos para probar los resultados del entrenamiento:
Comparando etiquetas reales y etiquetas previstas: 27 de 32 imágenes fueron juzgadas correctamente, con una tasa de precisión de aproximadamente el 84%.
GroundTruth: cat ship ship airplane frog frog automobile frog cat automobile airplane truck dog horse truck ship dog horse ship frog horse airplane deer truck
dog bird deer airplane truck frog frog dog
Predicted: cat ship ship airplane frog frog truck frog cat automobile airplane truck dog horse truck ship dog horse ship frog horse bird airplane truck deer frog deer airplane truck frog frog dog
Además, los resultados de las pruebas con 50.000 datos de entrenamiento mostraron una tasa de precisión del 92% y el resultado de 10.000 datos de pruebas nuevos fue del 77%. Entre las diez etiquetas, ship tiene la tasa de precisión más alta, 91%, y cat tiene la tasa de precisión más baja, casi 60%.
Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348
Dirección de descarga de recursos : https://download.csdn.net/download/sheziqiong/88284348