Implementing the LeNet-5 and AlexNet convolutional neural networks in Python: training and recognition

Resource download address : https://download.csdn.net/download/sheziqiong/88284348

Convolutional neural networks (CNN)

Experiment content and requirements

  • Write a program that implements the LeNet-5 convolutional neural network, train it to recognize digits from the MNIST handwritten digit database, and report the accuracy and related metrics.
  • Choose a neural network of your own and train it to recognize image objects from the CIFAR-10 database.

Experimental environment

Python 3.7

Development platform: Windows 10, Visual Studio Code

Machine learning library: torch 1.6.0, torchvision 0.7.0

Auxiliary: CUDA 10.2 for GPU acceleration

Implementation

3.1 LeNet-5 implementation

LeNet-5 is written as a class derived from torch's nn.Module. The convolution layers are created with nn.Conv2d() and the fully connected layers with nn.Linear(). The forward pass applies two pooling steps with F.max_pool2d, and after each layer the result is passed through F.relu() to produce the activated output.
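The structure described above can be sketched roughly as follows. This is a minimal sketch, not the original listing: the channel counts (6, 16), the fully connected sizes (120, 84, 10) and padding=2 on the first convolution for 28x28 MNIST input follow the classic LeNet-5 layout and are assumptions here. The last layer returns raw scores without ReLU, on the assumption that CrossEntropyLoss is applied afterwards.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    """LeNet-5 style network for 1x28x28 inputs (layer sizes are assumptions)."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)  # 28x28 -> 28x28
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)             # 14x14 -> 10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # first pooling  -> 6x14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # second pooling -> 16x5x5
        x = torch.flatten(x, 1)                     # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                          # raw class scores (logits)


model = LeNet5()
scores = model(torch.zeros(4, 1, 28, 28))  # dummy batch of four blank images
```

With a dummy batch of four images, scores has one row per image and one column per digit class.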


Loading the data with PyTorch's data-loading module was one difficulty encountered while implementing the convolutional neural network. torch.utils.data.DataLoader() takes the batch size, a flag for whether to shuffle the data, and num_workers (the number of worker processes used for loading). Because the experiment runs on Windows, where multi-process data loading is poorly supported, num_workers is best kept low (0 loads the data in the main process).
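A minimal sketch of the DataLoader call, assuming random tensors as a stand-in for MNIST so that nothing has to be downloaded. The dataset size of 1000 is arbitrary, and the batch size of 512 matches the value used later for LeNet-5.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 1000 random 1x28x28 "images" with random labels.
images = torch.randn(1000, 1, 28, 28)
labels = torch.randint(0, 10, (1000,))
dataset = TensorDataset(images, labels)

loader = DataLoader(
    dataset,
    batch_size=512,   # samples per batch
    shuffle=True,     # reshuffle the data every epoch
    num_workers=0,    # load in the main process; a cautious choice on Windows
)

first_images, first_labels = next(iter(loader))
```

On Windows, a script that sets num_workers above 0 generally needs its entry point wrapped in an `if __name__ == "__main__":` guard, because worker processes are started by re-importing the script.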

Training process: Use an optimizer (the Adam algorithm) and a loss function (the cross-entropy loss, CrossEntropyLoss), and call backward() on the loss to perform backpropagation. Call train() on the network before training: this puts it in training mode so that batch normalization and dropout layers, which help prevent overfitting, are active.
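One training step can be sketched like this. The single linear layer and the random batch are stand-ins chosen to keep the example short, and the learning rate of 1e-3 is an assumption; the original text does not state it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in model and synthetic batch (assumptions, not the article's network or data).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
images = torch.randn(32, 1, 28, 28)
labels = torch.randint(0, 10, (32,))

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

model.train()                            # training mode: dropout/batch norm active
optimizer.zero_grad()                    # clear gradients from the previous step
loss = criterion(model(images), labels)  # forward pass and cross-entropy loss
loss.backward()                          # backpropagation
optimizer.step()                         # Adam parameter update
loss_value = loss.item()
```

In a full run this step is repeated for every batch from the DataLoader, once per epoch.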

Test process: Switch the network to eval() mode, pass the input data through it, and take the index of the largest output value as the predicted class pred.
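A minimal sketch of the test step, again with a stand-in model and random data. Wrapping the forward pass in torch.no_grad() is an addition not mentioned in the text; it only avoids storing gradients during evaluation.

```python
import torch
import torch.nn as nn

# Stand-in model and synthetic test batch (assumptions for illustration only).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
images = torch.randn(16, 1, 28, 28)
labels = torch.randint(0, 10, (16,))

model.eval()                         # evaluation mode: dropout off, batch norm frozen
with torch.no_grad():                # no gradients are needed for testing
    output = model(images)           # shape: (batch, 10) class scores
    pred = output.argmax(dim=1)      # index of the largest score per image

correct = (pred == labels).sum().item()
accuracy = correct / labels.size(0)
```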

3.2 AlexNet implementation

The network definition is as follows:
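The original listing is not reproduced in this text. As a stand-in, here is a minimal AlexNet-style sketch that follows the layer layout of torchvision's AlexNet. It assumes 3x224x224 inputs and 10 output classes; the author's exact layer sizes may differ.

```python
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet-style network (torchvision layout, assumed) for 3x224x224 inputs."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # 224 -> 55
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 55 -> 27
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 27 -> 13
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # 13 -> 6
        )
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)       # 256x6x6 feature map -> 9216-dim vector
        return self.classifier(x)     # raw class scores


alexnet = AlexNet().eval()
with torch.no_grad():
    alexnet_scores = alexnet(torch.zeros(1, 3, 224, 224))
```

Calling print(alexnet) lists the layers in order, which is a quick way to check the structure.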

Note that the data must be preprocessed before training: torchvision's transform functions resize each image and convert it to a tensor. The Normalize function is then applied to map the tensor values from the range (0, 1) to the range (-1, 1).

Training and testing on CIFAR-10 follow the same procedure as for MNIST and are not described again.

Experimental results and analysis

4.1 LeNet-5 training and recognition of MNIST

Set BATCH_SIZE to 512 and train for a total of 10 epochs. In each epoch the network runs first over the training data and then over the test data to obtain the accuracy and loss values. The training and test output is saved in LeNet.log, and the model is saved as LeNet.pth.
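Logging and saving can be sketched as follows. The file names come from the text; everything else is an assumption: the temporary output directory, the log line format, the stand-in linear model, the synthetic data, and saving the state_dict rather than the whole model object. Only two epochs are run to keep the example quick.

```python
import logging
import os
import tempfile

import torch
import torch.nn as nn

out_dir = tempfile.mkdtemp()                       # scratch directory (assumption)
log_path = os.path.join(out_dir, "LeNet.log")
model_path = os.path.join(out_dir, "LeNet.pth")

logger = logging.getLogger("lenet_sketch")
logger.setLevel(logging.INFO)
handler = logging.FileHandler(log_path)
logger.addHandler(handler)

model = nn.Linear(28 * 28, 10)                     # stand-in for the trained network
x = torch.randn(64, 28 * 28)                       # synthetic data, not MNIST
y = torch.randint(0, 10, (64,))
criterion = nn.CrossEntropyLoss()

for epoch in range(1, 3):                          # the article trains for 10 epochs
    model.eval()
    with torch.no_grad():
        logits = model(x)
        loss = criterion(logits, y).item()
        acc = (logits.argmax(dim=1) == y).float().mean().item()
    logger.info("epoch %d  loss %.4f  accuracy %.4f", epoch, loss, acc)

handler.close()                                    # flush the log file
torch.save(model.state_dict(), model_path)         # save the weights only

restored = nn.Linear(28 * 28, 10)
restored.load_state_dict(torch.load(model_path))   # reload into a fresh model
```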

The training results are visualized as follows:

4.2 Training and recognition of CIFAR-10 by AlexNet

Set BATCH_SIZE to 32 and train for a total of 20 epochs. In each epoch the network runs first over the training data and then over the test data to obtain the accuracy and loss values. The training and test output is saved in AlexNet.log, and the model is saved as AlexNet.pth.

Because the AlexNet network is relatively complex and the CIFAR-10 dataset is also large, the structure of the trained network is printed as follows to verify that it is correct:

We first randomly select one batch of data to test the trained model:

Comparing actual labels and predicted labels: 27 out of 32 images were correctly judged, with an accuracy rate of approximately 84%.

GroundTruth: cat ship ship airplane frog frog automobile frog cat automobile airplane truck dog horse truck ship dog horse ship frog horse airplane deer truck dog bird deer airplane truck frog frog dog
Predicted:   cat ship ship airplane frog frog truck frog cat automobile airplane truck dog horse truck ship dog horse ship frog horse bird airplane truck deer frog deer airplane truck frog frog dog
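The count can be checked directly from the two label lists above. This is a small plain-Python sketch; the variable names are chosen here for illustration.

```python
ground_truth = (
    "cat ship ship airplane frog frog automobile frog cat automobile airplane "
    "truck dog horse truck ship dog horse ship frog horse airplane deer truck "
    "dog bird deer airplane truck frog frog dog"
).split()
predicted = (
    "cat ship ship airplane frog frog truck frog cat automobile airplane "
    "truck dog horse truck ship dog horse ship frog horse bird airplane truck "
    "deer frog deer airplane truck frog frog dog"
).split()

# Count positions where the predicted label equals the actual label.
correct = sum(t == p for t, p in zip(ground_truth, predicted))
accuracy = correct / len(ground_truth)
print(f"{correct}/{len(ground_truth)} correct, accuracy {accuracy:.1%}")
# prints: 27/32 correct, accuracy 84.4%
```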

In addition, evaluating the trained model on the 50,000 training images gave an accuracy of 92%, while the 10,000 previously unseen test images gave 77%. Among the ten classes, ship had the highest accuracy at 91% and cat the lowest at nearly 60%.
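Per-class accuracy of this kind can be computed by counting, for each label, how many of its samples were predicted correctly. The helper name per_class_accuracy and the five toy labels below are hypothetical; they are not the article's data.

```python
from collections import Counter


def per_class_accuracy(truth, pred):
    """Return {label: fraction of samples with that true label predicted correctly}."""
    total = Counter(truth)
    correct = Counter(t for t, p in zip(truth, pred) if t == p)
    return {label: correct[label] / total[label] for label in total}


# Toy example (hypothetical labels, for illustration only).
truth = ["cat", "cat", "ship", "ship", "ship"]
pred = ["cat", "dog", "ship", "ship", "ship"]
result = per_class_accuracy(truth, pred)  # {'cat': 0.5, 'ship': 1.0}
```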


Origin blog.csdn.net/newlw/article/details/132625381