Resource download address : https://download.csdn.net/download/sheziqiong/88284348
CNN convolutional neural network
Experiment content and requirements
- Write a program that implements the LeNet-5 convolutional neural network, trains it to recognize the MNIST handwritten digit database, and reports the accuracy and related metrics.
- Choose a neural network of your own and use it to train on and recognize the image objects in the CIFAR-10 database.
Experiment environment
Python 3.7
Development platform: Windows 10, Visual Studio Code
Machine learning library: torch 1.6.0 torchvision 0.7.0
Auxiliary: CUDA 10.2 for GPU acceleration
Implementation
3.1 LeNet-5 implementation
LeNet-5 is written as a subclass of torch's nn.Module: nn.Conv2d() defines the convolutional layers and nn.Linear() the fully connected layers. In the forward pass, two pooling operations are applied with F.max_pool2d, and the output of each layer is passed through F.relu() to form the input of the next layer.
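A minimal sketch of such a class, assuming the classic LeNet-5 layer sizes (6 and 16 feature maps, 120/84 hidden units) and `padding=2` on the first convolution so that 28×28 MNIST images behave like the original 32×32 inputs. These sizes are assumptions, not taken from the author's code. One deliberate deviation: the last layer returns raw scores without ReLU, because `CrossEntropyLoss` expects unnormalized logits.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet5(nn.Module):
    """LeNet-5-style network for 1x28x28 MNIST images (sketch, assumed layer sizes)."""

    def __init__(self, num_classes=10):
        super().__init__()
        # padding=2 keeps a 28x28 input at 28x28 after the 5x5 convolution
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 6x28x28 -> 6x14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 16x10x10 -> 16x5x5
        x = torch.flatten(x, 1)                     # keep the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)  # raw logits; CrossEntropyLoss applies log-softmax itself


if __name__ == "__main__":
    net = LeNet5()
    print(net(torch.randn(4, 1, 28, 28)).shape)  # torch.Size([4, 10])
```

With these sizes the network has 61,706 trainable parameters, the figure usually quoted for LeNet-5.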
Using PyTorch's data loading module was one of the difficulties encountered while implementing the network. torch.utils.data.DataLoader() takes the batch size, whether to shuffle the data, and num_workers (the number of worker processes used for loading). Because the experiments ran on Windows, where multi-process data loading is poorly supported, num_workers has to be chosen with care.
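A small sketch of the loader setup. `build_loader` is a hypothetical helper, and random tensors with MNIST's shape stand in for `torchvision.datasets.MNIST` so the example runs without a download. `num_workers=0` (loading in the main process) is assumed here as the safe Windows default; on Windows, worker processes are spawned, so a script that uses `num_workers > 0` must keep its entry point under `if __name__ == "__main__":`.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def build_loader(dataset, batch_size=512, shuffle=True, num_workers=0):
    # batch_size: samples per batch; shuffle: reorder the data every epoch;
    # num_workers=0 loads batches in the main process (safest on Windows)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                      num_workers=num_workers)


if __name__ == "__main__":
    # Stand-in data with MNIST's shape (1x28x28 images, labels 0-9); in practice:
    # torchvision.datasets.MNIST(root, train=True, download=True, transform=...)
    images = torch.randn(1000, 1, 28, 28)
    labels = torch.randint(0, 10, (1000,))
    loader = build_loader(TensorDataset(images, labels))
    x, y = next(iter(loader))
    print(len(loader), x.shape, y.shape)  # 2 torch.Size([512, 1, 28, 28]) torch.Size([512])
```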
Training process: an optimizer (the Adam algorithm) and a loss function (cross-entropy, CrossEntropyLoss) are used, and backward() is called on the loss to perform backpropagation. Put the network in train() mode before training: this makes batch normalization and dropout layers, where the network has them, use their training behavior, and dropout helps prevent the network from overfitting.
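One training epoch along these lines might look like the sketch below. `train_one_epoch` is a hypothetical helper; the tiny linear model, random data and learning rate of 1e-3 (Adam's default) are placeholders, not the author's settings.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """Run one pass over `loader` and return the average loss per sample."""
    criterion = nn.CrossEntropyLoss()
    model.train()  # training behavior for dropout / batch-norm layers, if present
    total_loss, total = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()                    # clear gradients from the last step
        loss = criterion(model(images), labels)
        loss.backward()                          # backpropagation
        optimizer.step()                         # Adam parameter update
        total_loss += loss.item() * images.size(0)
        total += images.size(0)
    return total_loss / total


if __name__ == "__main__":
    torch.manual_seed(0)
    # Stand-ins: substitute the LeNet-5 network and the MNIST loader.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    data = TensorDataset(torch.randn(256, 1, 28, 28), torch.randint(0, 10, (256,)))
    loader = DataLoader(data, batch_size=64, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(3):
        print(epoch, train_one_epoch(model, loader, optimizer))
```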
Test process: switch the network to eval() mode, propagate the input data through it, and take the index of the largest output value as the predicted class pred.
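A matching evaluation sketch (the helper name `evaluate` is hypothetical). Gradients are disabled with `torch.no_grad()`, and `argmax(dim=1)` returns the index of the largest output, which is the predicted class.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device="cpu"):
    """Return (average loss, accuracy) of `model` on `loader`."""
    criterion = nn.CrossEntropyLoss(reduction="sum")  # sum, then divide by sample count
    model.eval()                       # inference behavior for dropout / batch norm
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():              # no gradients are needed at test time
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images)
            total_loss += criterion(output, labels).item()
            pred = output.argmax(dim=1)  # index of the largest output = predicted class
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total


if __name__ == "__main__":
    # Stand-in model and data; substitute the trained network and the test loader.
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    data = TensorDataset(torch.randn(100, 1, 28, 28), torch.randint(0, 10, (100,)))
    print(evaluate(model, DataLoader(data, batch_size=32)))
```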
3.2 AlexNet implementation
The network definition is as follows:
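For illustration, an AlexNet-style network for CIFAR-10 could be sketched as below. This is an assumption rather than the author's exact definition: it mirrors the layer layout of torchvision's AlexNet with `num_classes=10`, assumes images resized to 224×224, and adds a hypothetical `hidden` parameter for the width of the fully connected layers.

```python
import torch
import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet-style classifier for CIFAR-10 images resized to 3x224x224 (sketch)."""

    def __init__(self, num_classes=10, hidden=4096):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                 # 55x55 -> 27x27
            nn.Conv2d(64, 192, kernel_size=5, padding=2), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                 # 27x27 -> 13x13
            nn.Conv2d(192, 384, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                 # 13x13 -> 6x6
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))  # fixes the flattened size at 256*6*6
        self.classifier = nn.Sequential(
            nn.Dropout(0.5), nn.Linear(256 * 6 * 6, hidden), nn.ReLU(inplace=True),
            nn.Dropout(0.5), nn.Linear(hidden, hidden), nn.ReLU(inplace=True),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, x):
        x = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(x, 1))  # raw logits for CrossEntropyLoss
```

The dropout layers in the classifier are the ones that `train()` and `eval()` switch on and off.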
Note that the data must be preprocessed before training: torchvision's transforms resize each image and convert it to a tensor, and the Normalize transform then maps the tensor values from the [0, 1] range to [-1, 1].
Training and testing on CIFAR-10 follow the same procedure as for MNIST and are not described again.
Experimental results and analysis
4.1 LeNet-5 training and recognition of MNIST
BATCH_SIZE is set to 512 and the network is trained for 10 epochs. Each epoch makes one pass over the training data and then one over the test data, recording the accuracy and loss values. The training and testing output is saved in LeNet.log, and the model is saved as LeNet.pth.
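A minimal sketch of this per-epoch bookkeeping. `log_epoch`, `save_model` and `load_model` are hypothetical helpers, the log line format is an assumption, and saving the `state_dict` with `torch.save` is one common way to produce a file such as LeNet.pth, not necessarily the author's.

```python
import os
import tempfile

import torch
import torch.nn as nn


def log_epoch(log_path, epoch, train_loss, test_loss, test_acc):
    # Append one line per epoch so the run can be inspected or plotted afterwards.
    with open(log_path, "a") as f:
        f.write(f"epoch {epoch}: train_loss={train_loss:.4f} "
                f"test_loss={test_loss:.4f} test_acc={test_acc:.4f}\n")


def save_model(model, path):
    # Save only the parameters (state_dict), not the pickled module object.
    torch.save(model.state_dict(), path)


def load_model(model, path):
    # `model` must be a fresh instance of the same architecture as the saved one.
    model.load_state_dict(torch.load(path, map_location="cpu"))
    return model


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        net = nn.Linear(4, 2)                         # stand-in for the trained network
        path = os.path.join(d, "LeNet.pth")
        save_model(net, path)
        restored = load_model(nn.Linear(4, 2), path)
        print(torch.equal(net.weight, restored.weight))  # True
```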
The training results are visualized as follows:
4.2 Training and recognition of CIFAR-10 by AlexNet
BATCH_SIZE is set to 32 and the network is trained for 20 epochs. Each epoch makes one pass over the training data and then one over the test data, recording the accuracy and loss values. The training and testing output is saved in AlexNet.log, and the model is saved as AlexNet.pth.
Because AlexNet is relatively complex and CIFAR-10 is a fairly large dataset, the structure of the trained network is printed as follows to verify that it is correct:
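Beyond printing the module, a quick sanity check is to count parameters and push one dummy sample through the network. The helper `check_network` below is hypothetical; the input shape is whatever the network was built for (for example `(3, 224, 224)` if images are resized to 224×224, which is an assumption).

```python
import torch
import torch.nn as nn


def check_network(model, input_shape, num_classes):
    """Print the structure, count parameters and verify the output shape."""
    print(model)  # nn.Module's repr lists every layer with its settings
    n_params = sum(p.numel() for p in model.parameters())
    print(f"parameters: {n_params}")
    model.eval()
    with torch.no_grad():
        out = model(torch.zeros(1, *input_shape))  # one dummy sample
    if out.shape != (1, num_classes):
        raise ValueError(f"expected output (1, {num_classes}), got {tuple(out.shape)}")
    return n_params


if __name__ == "__main__":
    # Small stand-in network; for a full model: check_network(model, (3, 224, 224), 10)
    demo = nn.Sequential(nn.Flatten(), nn.Linear(3 * 2 * 2, 10))
    check_network(demo, (3, 2, 2), 10)
```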
A randomly selected batch of data is first used to test the trained network:
Comparing the ground-truth and predicted labels, 27 of the 32 images were classified correctly, an accuracy of about 84% (27/32 ≈ 84.4%).
GroundTruth: cat ship ship airplane frog frog automobile frog cat automobile airplane truck dog horse truck ship dog horse ship frog horse airplane deer truck dog bird deer airplane truck frog frog dog
Predicted: cat ship ship airplane frog frog truck frog cat automobile airplane truck dog horse truck ship dog horse ship frog horse bird airplane truck deer frog deer airplane truck frog frog dog
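The 27/32 count can be checked directly from the two label sequences listed for this batch; the short script below also lists the mismatched positions (0-based).

```python
ground_truth = ("cat ship ship airplane frog frog automobile frog cat automobile "
                "airplane truck dog horse truck ship dog horse ship frog horse "
                "airplane deer truck dog bird deer airplane truck frog frog dog").split()
predicted = ("cat ship ship airplane frog frog truck frog cat automobile "
             "airplane truck dog horse truck ship dog horse ship frog horse "
             "bird airplane truck deer frog deer airplane truck frog frog dog").split()

# Count positions where the predicted label equals the ground truth.
correct = sum(g == p for g, p in zip(ground_truth, predicted))
print(f"{correct}/{len(ground_truth)} = {correct / len(ground_truth):.1%}")  # 27/32 = 84.4%

# (index, ground truth, prediction) for every misclassified image
mistakes = [(i, g, p) for i, (g, p) in enumerate(zip(ground_truth, predicted)) if g != p]
print(mistakes)
```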
In addition, on the 50,000 training images the network reaches an accuracy of 92%, and on the 10,000 unseen test images it reaches 77%. Among the ten classes, ship has the highest accuracy (91%) and cat the lowest (nearly 60%).
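Per-class figures like these can be computed from the predictions and labels collected over the test set. The sketch below uses `torch.bincount`; `per_class_accuracy` is a hypothetical helper, and the toy tensors in the demo are made up for illustration only.

```python
import torch

# CIFAR-10 class names in label-index order (0-9)
CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck")


def per_class_accuracy(preds, labels, num_classes=len(CLASSES)):
    """Return a tensor holding the accuracy of each class index."""
    total = torch.bincount(labels, minlength=num_classes)                     # samples per class
    correct = torch.bincount(labels[preds == labels], minlength=num_classes)  # hits per class
    return correct.float() / total.clamp(min=1).float()  # clamp avoids 0/0 for absent classes


if __name__ == "__main__":
    # Toy tensors; in practice, concatenate preds/labels over the whole test set.
    labels = torch.tensor([3, 3, 3, 8, 8])
    preds = torch.tensor([3, 5, 3, 8, 8])
    for name, a in zip(CLASSES, per_class_accuracy(preds, labels).tolist()):
        print(f"{name:>10s}: {a:.0%}")
```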