[PyTorch Practical Exercise] Use the Cifar10 data set to train the LeNet5 network and implement image classification (with code)

0. Preface

As is customary, let me state up front: this article reflects only my own understanding from studying the topic. Although I drew on the valuable insights of others, it may still contain inaccuracies. If you find any errors, please point them out so that we can improve together.

This article is a hands-on exercise: using the PyTorch framework, we implement image classification with the LeNet5 network, trained on the Cifar10 data set. The goal is to deepen the understanding of deep learning, especially convolutional neural networks, through practice.

This article is a complete step-by-step guide. With only the most basic deep learning knowledge you can follow it to build the LeNet5 network from scratch with the PyTorch library, train it, and finally use it to recognize objects in real photographs.

1. Cifar10 data set

The Cifar10 data set was released in 2009 and was collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. Cifar10 is an image classification data set with 10 categories. Each category contains 6,000 color images of 32×32 pixels, for a total of 60,000 images: 50,000 are used to train the network model (the training group) and 10,000 are used to validate it (the validation group; the official release and torchvision call this the test set).

The name Cifar10 comes from the Canadian Institute for Advanced Research (CIFAR) plus the number of categories; its companion set, Cifar100, has 100 categories.

1.1 Cifar10 data set download

Use torchvision to download Cifar10 directly:

from torchvision import datasets
from torchvision import transforms

data_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(root=data_path, train=True, download=True, transform=transforms.ToTensor())  # set download=True on the first run

Parameters of datasets.CIFAR10:

  • root: directory in which the data set is stored (and into which it is downloaded).
  • train: if True, the training group (50,000 images) is loaded; if False, the validation group (10,000 images) is loaded.
  • download: set to True to download the data; once the data is on disk, it can be set to False.
  • transform: a transformation applied to each image. transforms.ToTensor() converts the image into a tensor and scales the pixel values to the range [0, 1], which saves us from writing separate normalization code.
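To see what transforms.ToTensor() does to an 8-bit RGB image, here is a minimal sketch that reproduces the conversion by hand with NumPy and PyTorch: the axes go from [H, W, C] to [C, H, W] and the values 0..255 become floats in [0, 1]. The helper to_tensor_like and the fake image are hypothetical, for illustration only; in real code just use transforms.ToTensor().

```python
import numpy as np
import torch


def to_tensor_like(img_hwc: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: mimics transforms.ToTensor() for an H x W x C uint8 image."""
    # Reorder the axes from [H, W, C] to [C, H, W] ...
    chw = torch.from_numpy(img_hwc).permute(2, 0, 1)
    # ... and rescale the integer pixel values 0..255 to floats in [0, 1]
    return chw.float() / 255.0


# A fake 32x32 RGB image whose pixel values cycle through 0..255
fake_img = (np.arange(32 * 32 * 3) % 256).astype(np.uint8).reshape(32, 32, 3)
x = to_tensor_like(fake_img)
print(x.shape, x.dtype)                # torch.Size([3, 32, 32]) torch.float32
print(x.min().item(), x.max().item())  # 0.0 1.0
```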
1.2 Cifar10 data set analysis

After downloading, you can take a look at the specific contents of the Cifar10 data set:

print(type(cifar10))
print(cifar10[0])
------------------------Output------------------------------------
<class 'torchvision.datasets.cifar.CIFAR10'>
(tensor([[[0.2314, 0.1686, 0.1961,  ..., 0.6196, 0.5961, 0.5804],
         [0.0627, 0.0000, 0.0706,  ..., 0.4824, 0.4667, 0.4784],
         [0.0980, 0.0627, 0.1922,  ..., 0.4627, 0.4706, 0.4275],
         ...,
         [0.8157, 0.7882, 0.7765,  ..., 0.6275, 0.2196, 0.2078],
         [0.7059, 0.6784, 0.7294,  ..., 0.7216, 0.3804, 0.3255],
         [0.6941, 0.6588, 0.7020,  ..., 0.8471, 0.5922, 0.4824]],

        [[0.2431, 0.1804, 0.1882,  ..., 0.5176, 0.4902, 0.4863],
         [0.0784, 0.0000, 0.0314,  ..., 0.3451, 0.3255, 0.3412],
         [0.0941, 0.0275, 0.1059,  ..., 0.3294, 0.3294, 0.2863],
         ...,
         [0.6667, 0.6000, 0.6314,  ..., 0.5216, 0.1216, 0.1333],
         [0.5451, 0.4824, 0.5647,  ..., 0.5804, 0.2431, 0.2078],
         [0.5647, 0.5059, 0.5569,  ..., 0.7216, 0.4627, 0.3608]],

        [[0.2471, 0.1765, 0.1686,  ..., 0.4235, 0.4000, 0.4039],
         [0.0784, 0.0000, 0.0000,  ..., 0.2157, 0.1961, 0.2235],
         [0.0824, 0.0000, 0.0314,  ..., 0.1961, 0.1961, 0.1647],
         ...,
         [0.3765, 0.1333, 0.1020,  ..., 0.2745, 0.0275, 0.0784],
         [0.3765, 0.1647, 0.1176,  ..., 0.3686, 0.1333, 0.1333],
         [0.4549, 0.3686, 0.3412,  ..., 0.5490, 0.3294, 0.2824]]]), 6)

It can be seen that Cifar10 has its own separate data type torchvision.datasets.cifar.CIFAR10, and its structure is similar to a list.

Each element, for example the first one, cifar10[0], contains:

  • A tensor of shape [3, 32, 32] (because transform was set to ToTensor above): the RGB three-channel image data.
  • An integer label, here 6, giving the true class of the image. Labels map to classes as follows: 0 airplane, 1 automobile, 2 bird, 3 cat, 4 deer, 5 dog, 6 frog, 7 horse, 8 ship, 9 truck.

Here we can also use matplotlib to convert the tensor data of the image back to the image to see what the image with label 6 looks like:

from torchvision import datasets
import matplotlib.pyplot as plt
from torchvision import transforms

data_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(root=data_path, train=True, download=False, transform=transforms.ToTensor())  # set download=True on the first run

# print(type(cifar10))
# print(cifar10[0])

img,label = cifar10[0]
plt.imshow(img.permute(1,2,0))
plt.show()

The output is:
[Figure: cifar10[0] displayed with matplotlib, a 32×32 image of a frog]
Yes, this is a frog, label 6. A 32×32 pixel image cannot show much more detail than this.

We use .permute() here because the original tensor has dimensions [channel 3, H 32, W 32], whereas .imshow() expects [H, W, channel], so the dimension order must be rearranged.
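A small sketch of what .permute(1, 2, 0) does, using a random tensor as a stand-in for a Cifar10 image (an assumption so it runs without the data set). permute returns a view over the same memory, so no pixel data is copied; only the indexing order changes.

```python
import torch

# Stand-in for one Cifar10 image as returned by ToTensor: [C, H, W]
img = torch.rand(3, 32, 32)

# plt.imshow expects [H, W, C], so move the channel axis to the end
hwc = img.permute(1, 2, 0)
print(hwc.shape)  # torch.Size([32, 32, 3])

# Pixel (row 5, column 7) holds the same three RGB values in both layouts
print(torch.equal(hwc[5, 7], img[:, 5, 7]))  # True
```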

2. LeNet5 network

LeNet5, developed by Yann LeCun and colleagues in the 1990s, is a classic convolutional neural network. It consists of 7 layers: 2 convolutional layers, 2 pooling layers, and 3 fully connected layers. For its time, it was innovative in using convolutional and pooling layers to extract features from the input, which reduces the number of parameters and makes the network more robust to small shifts and distortions of the input image.

LeNet5 was widely used for handwritten digit recognition and can also be applied to other image classification tasks. Although modern deep convolutional neural networks perform far better, LeNet5 remains valuable for learning the basic principles and methods of convolutional neural networks.

2.1 Network structure of LeNet5

The network structure of LeNet5 is as follows:
[Figure: LeNet5 network structure]

The input of LeNet5 is a 32x32 image:

  • The first layer is a convolutional layer, including 6 5x5 convolution kernels, and the output feature map is 28x28
  • The second layer is a 2x2 max pooling layer, which reduces the feature map size by half to 14×14.
  • The third layer is another convolutional layer, including 16 5x5 convolution kernels, and the output feature map is 10x10
  • The fourth layer is the same as the second layer, reducing the size of the feature map by half to 5×5
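  • The fifth layer is a fully connected layer containing 120 neurons (the code below implements it as a convolution with 120 5x5 kernels, which on a 5×5 input is equivalent to a fully connected layer).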
  • The sixth layer is another fully connected layer, containing 84 neurons.
  • The last layer is the output layer, which contains 10 neurons, and each neuron corresponds to a label.
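The feature map sizes in the list above follow from the usual output-size formula for convolution and pooling layers without dilation, out = floor((in + 2·padding − kernel) / stride) + 1. The helper out_size below is a hypothetical name used only to trace the arithmetic from the 32×32 input.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square conv/pool layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32
for name, k, s in [("conv1 5x5", 5, 1), ("pool1 2x2", 2, 2),
                   ("conv2 5x5", 5, 1), ("pool2 2x2", 2, 2)]:
    size = out_size(size, k, s)
    print(f"{name}: {size}x{size}")
# conv1 5x5: 28x28
# pool1 2x2: 14x14
# conv2 5x5: 10x10
# pool2 2x2: 5x5

# A 5x5 kernel on the remaining 5x5 map leaves a single value per kernel
print(out_size(5, 5))     # 1
# Values entering the fifth layer: 16 channels * 5 * 5
print(16 * size * size)   # 400
```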
2.2 LeNet5 network coding based on PyTorch

Based on the LeNet5 network structure above, the code is as follows:

import torch.nn as nn

class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  # the images are RGB, so in_channels = 3
            # output tensor: Batch(1) x Channel(6) x H(28) x W(28)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # output tensor: Batch(1) x Channel(6) x H(14) x W(14)
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            # output tensor: Batch(1) x Channel(16) x H(10) x W(10)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # output tensor: Batch(1) x Channel(16) x H(5) x W(5)
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
            # output tensor: Batch(1) x Channel(120) x H(1) x W(1)
            nn.Flatten(),
            # flatten to one dimension as input for the fully connected layers
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        return self.net(x)
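
To check the shape comments in the code and see how small the network is, one can push a dummy batch through each layer and count the parameters. The class is repeated here in a more compact form (same layers) so that the snippet runs on its own; the all-zeros input is only a placeholder.

```python
import torch
import torch.nn as nn


class LeNet(nn.Module):
    # Same layers as the LeNet class above, written more compactly
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5), nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 120, kernel_size=5),
            nn.Flatten(),
            nn.Linear(120, 84), nn.Sigmoid(),
            nn.Linear(84, 10),
        )

    def forward(self, x):
        return self.net(x)


model = LeNet()
x = torch.zeros(1, 3, 32, 32)  # dummy batch holding one 32x32 RGB image
for layer in model.net:
    x = layer(x)
    print(f"{layer.__class__.__name__:<10} {tuple(x.shape)}")  # last line: Linear (1, 10)

n_params = sum(p.numel() for p in model.parameters())
print("trainable parameters:", n_params)  # 62006
```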

3. LeNet5 network training and output verification

3.1 LeNet5 network training

Because my computer has no GPU, training with the CPU version of PyTorch is very slow, so I trained on only the first 2,000 images of Cifar10 (T_T):

small_cifar10 = []
for i in range(2000):
    small_cifar10.append(cifar10[i])
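
As an alternative to copying samples into a list, torch.utils.data.Subset can select the first 2,000 samples and a DataLoader can serve them in shuffled mini-batches. In this sketch a TensorDataset of random tensors stands in for cifar10 (an assumption so it runs without the download; with the real data it would be Subset(cifar10, range(2000))), and the batch size of 64 is an arbitrary choice.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in for `cifar10`: 2500 random 3x32x32 "images" with labels 0..9
fake_cifar10 = TensorDataset(torch.rand(2500, 3, 32, 32), torch.randint(0, 10, (2500,)))

# Keep only the first 2000 samples without copying them into a list
small_cifar10 = Subset(fake_cifar10, range(2000))

# Serve them in shuffled mini-batches of 64
loader = DataLoader(small_cifar10, batch_size=64, shuffle=True)

imgs, labels = next(iter(loader))
print(len(small_cifar10), imgs.shape, labels.shape)
# 2000 torch.Size([64, 3, 32, 32]) torch.Size([64])
```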

The training settings are as follows:

  • Loss function: cross-entropy loss, nn.CrossEntropyLoss()
  • Optimizer: stochastic gradient descent, torch.optim.SGD()
  • Epochs and learning rate: this is the troublesome part. So far I have not found a good way to choose the number of epochs and the lr up front, so I could only tune them step by step. To avoid wasting any training run, I saved the weights after each run and started the next run from the previous result. For how to save and load weights, see my earlier blog post: Learn Pytorch loading weights .load_state_dict() and saving weights .save() through examples. The figure below shows my exploration: lr was gradually reduced from about 1e-5 to 2e-7, training ran for about 3,000 epochs in total, and the loss (summed over the 2,000 training images in each epoch) dropped from about 10,000 at the start to below 100.

I did not fully record the detailed parameters (epoch and lr) of each step of this training. If you need the trained weights, leave your email and I will send them to you. Readers are also welcome to explore better training parameters.

[Figure: total training loss over epochs during the learning-rate exploration]
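For comparison, a common training setup feeds shuffled mini-batches and calls opt.zero_grad() before every backward(), because PyTorch adds new gradients to whatever is already stored in .grad. The sketch below assumes a small linear stand-in model and random data (both hypothetical) so that it runs on its own; lr=0.01, the batch size, and the epoch count are illustrative, not tuned values. It also picks a GPU when one is available, and ends with the state_dict() save and reload that the resume-from-previous-weights approach above relies on.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins: substitute LeNet() and the 2000-image Cifar10 subset in practice
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
data = TensorDataset(torch.rand(256, 3, 32, 32), torch.randint(0, 10, (256,)))
loader = DataLoader(data, batch_size=32, shuffle=True)

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)  # illustrative lr, not the author's setting

epoch_losses = []
for epoch in range(20):
    running = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        opt.zero_grad()                        # clear gradients left from the previous step
        loss = loss_fn(model(imgs), labels)    # mean loss over this mini-batch
        loss.backward()                        # gradients of this mini-batch only
        opt.step()
        running += loss.item() * imgs.size(0)  # .item() avoids keeping the graph alive
    epoch_losses.append(running / len(data))   # mean loss per sample in this epoch

print(f"first epoch {epoch_losses[0]:.3f}, last epoch {epoch_losses[-1]:.3f}")

# Save the weights, then reload them into a fresh model to resume later
ckpt = os.path.join(tempfile.gettempdir(), "lenet_sketch.pth")
torch.save(model.state_dict(), ckpt)
restored = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10)).to(device)
restored.load_state_dict(torch.load(ckpt, map_location=device))
```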

3.2 LeNet5 network verification

The exciting time is coming! Now let’s verify whether our trained network can accurately identify the target image!

For verification I chose photos of the G6, a model launched by Xpeng Motors in 2023:
[Figure: the Xpeng G6 test photos]
Load the trained weight file and feed the images into the model:

import torch
from PIL import Image
from torchvision import transforms


def img_totensor(img_file):
    img = Image.open(img_file).convert('RGB')  # make sure the image has exactly 3 channels
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])
    img_tensor = transform(img).unsqueeze(0)  # add a batch dimension: [1, 3, 32, 32]

    return img_tensor

test_model = LeNet()  # the LeNet class from section 2.2
test_model.load_state_dict(torch.load('CIFAR10/small2000_8.pth'))

img1 = img_totensor('1.jpg')
img2 = img_totensor('2.jpg')
img3 = img_totensor('3.jpg')
img4 = img_totensor('4.jpg')

print(test_model(img1))
print(test_model(img2))
print(test_model(img3))
print(test_model(img4))

The final output is as follows:

tensor([[ 8.4051, 12.0952, -7.9274,  0.3868, -3.0866, -4.7883, -1.6089, -3.6484,
         -1.1387,  4.7348]], grad_fn=<AddmmBackward0>)
tensor([[-1.1992, 17.4531, -2.7929, -6.0410, -1.7589, -2.6942, -3.6753, -2.6800,
          3.6378,  2.4267]], grad_fn=<AddmmBackward0>)
tensor([[ 1.7580, 10.6321, -5.3922, -0.4557, -2.0147, -0.5974, -0.5785, -4.7977,
         -1.2916,  5.4786]], grad_fn=<AddmmBackward0>)
tensor([[10.5689,  6.2413, -0.9554, -4.4162,  1.0807, -7.9541, -5.3185, -6.0609,
          5.1129,  4.2243]], grad_fn=<AddmmBackward0>)

Let’s interpret this output:

  • For the 1st, 2nd, and 3rd images, the largest value of the output tensor is at index [1] (counting from 0), i.e. label 1, automobile (car), so these predictions are correct.
  • The prediction for the 4th image is wrong: the largest value is at index [0], so LeNet5 thinks this image is an airplane.
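Reading off the largest element can be done in code with argmax; softmax additionally turns the raw scores into per-class probabilities without changing which class wins, since it is monotonic. This sketch copies the four output rows printed above into a tensor; CLASSES is the Cifar10 label order (the same list torchvision exposes as cifar10.classes).

```python
import torch

# Cifar10 class names in label order
CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

# The four output rows printed above, as a [4, 10] tensor
logits = torch.tensor([
    [8.4051, 12.0952, -7.9274, 0.3868, -3.0866, -4.7883, -1.6089, -3.6484, -1.1387, 4.7348],
    [-1.1992, 17.4531, -2.7929, -6.0410, -1.7589, -2.6942, -3.6753, -2.6800, 3.6378, 2.4267],
    [1.7580, 10.6321, -5.3922, -0.4557, -2.0147, -0.5974, -0.5785, -4.7977, -1.2916, 5.4786],
    [10.5689, 6.2413, -0.9554, -4.4162, 1.0807, -7.9541, -5.3185, -6.0609, 5.1129, 4.2243],
])

probs = torch.softmax(logits, dim=1)  # scores -> probabilities, each row sums to 1
pred = probs.argmax(dim=1)            # index of the highest score per image

for i, (p, k) in enumerate(zip(probs, pred), start=1):
    idx = int(k)
    print(f"image {i}: {CLASSES[idx]} ({p[idx].item():.1%})")
# predicted classes: automobile, automobile, automobile, airplane
```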

Although this accuracy is not high, remember that I trained on only the first 2,000 images of Cifar10, and that LeNet5 takes a 32×32 image as input; as the frog above shows, images that small are hard to classify even for people.
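Four photos say little about overall accuracy. A sketch of measuring it over a whole data set follows; with the real data, the loader would wrap datasets.CIFAR10(data_path, train=False, transform=transforms.ToTensor()) (the 10,000-image validation group) and the model would be the trained LeNet. The helper evaluate_accuracy and the AlwaysZero dummy model are hypothetical, used only to check the function on random data where the correct answer is known.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for imgs, labels in loader:
            pred = model(imgs).argmax(dim=1)
            correct += (pred == labels).sum().item()
            total += labels.numel()
    return correct / total


class AlwaysZero(nn.Module):
    """Dummy model that always gives class 0 the highest score."""
    def forward(self, x):
        scores = torch.zeros(x.size(0), 10)
        scores[:, 0] = 1.0
        return scores


# 100 random images, 30 of which carry label 0, so the expected accuracy is 0.3
labels = torch.tensor([0] * 30 + [1] * 70)
loader = DataLoader(TensorDataset(torch.rand(100, 3, 32, 32), labels), batch_size=32)
print(evaluate_accuracy(AlwaysZero(), loader))  # 0.3
```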

4. Complete code

4.1 Training code
# Save this file as CIFAR10_main.py; the verification code imports LeNet from it
from torchvision import datasets
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from tqdm import tqdm


data_path = 'CIFAR10/IMG_file'
cifar10 = datasets.CIFAR10(data_path, train=True, download=False, transform=transforms.ToTensor())  # set download=True on the first run


class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),  # the images are RGB, so in_channels = 3
            # output tensor: Batch(1) x Channel(6) x H(28) x W(28)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # output tensor: Batch(1) x Channel(6) x H(14) x W(14)
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            # output tensor: Batch(1) x Channel(16) x H(10) x W(10)
            nn.Sigmoid(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # output tensor: Batch(1) x Channel(16) x H(5) x W(5)
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
            # output tensor: Batch(1) x Channel(120) x H(1) x W(1)
            nn.Flatten(),
            # flatten to one dimension as input for the fully connected layers
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, 10)
        )

    def forward(self, x):
        return self.net(x)

if __name__ == '__main__':
    model = LeNet()
    # resume from the weights saved by the previous run (remove this line on the very first run)
    model.load_state_dict(torch.load('CIFAR10/small2000_7.pth'))

    loss = nn.CrossEntropyLoss()
    # the author's final lr; with gradients cleared every step (below), a larger lr is likely needed
    opt = torch.optim.SGD(model.parameters(), lr=2e-7)

    small_cifar10 = []
    for i in range(2000):
        small_cifar10.append(cifar10[i])

    for epoch in range(1000):
        total_loss = 0.0  # summed loss over all 2000 images of this epoch
        for img, label in tqdm(small_cifar10):
            opt.zero_grad()  # clear the gradients left from the previous sample
            output = model(img.unsqueeze(0))  # add a batch dimension
            label = torch.tensor([label])
            LeNet_loss = loss(output, label)
            total_loss += LeNet_loss.item()  # plain float, so the autograd graph is not kept alive
            LeNet_loss.backward()
            opt.step()

        plt.scatter(epoch, total_loss, c='b')
        print(total_loss)
        print("epoch=", epoch)

    torch.save(model.state_dict(), 'CIFAR10/small2000_8.pth')
    plt.show()

4.2 Verification code
import torch
from torchvision import transforms
from PIL import Image
from CIFAR10_main import LeNet

def img_totensor(img_file):
    img = Image.open(img_file).convert('RGB')  # make sure the image has exactly 3 channels
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((32, 32))])
    img_tensor = transform(img).unsqueeze(0)  # add a batch dimension: [1, 3, 32, 32]

    return img_tensor

test_model = LeNet()
test_model.load_state_dict(torch.load('CIFAR10/small2000_8.pth'))

img1 = img_totensor('1.jpg')
img2 = img_totensor('2.jpg')
img3 = img_totensor('3.jpg')
img4 = img_totensor('4.jpg')

print(test_model(img1))
print(test_model(img2))
print(test_model(img3))
print(test_model(img4))

Origin blog.csdn.net/m0_49963403/article/details/133365347