第一课:LeNet学习


一、pytorch官网入门demo——实现一个图像分类器

demo用的是CIFAR10数据集,也是一个很经典的图像分类数据集,由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集,一共包含 10 个类别的 RGB 彩色图片。

二、代码部分

1.module.py----定义LeNet的网络结构

代码如下(示例):

import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):#定义一个类 继承于副类nn.module
    def __init__(self):#初始化函数:搭建中所需要的网络结构
        super(LeNet, self).__init__()#super解决调用类问题
        self.conv1 = nn.Conv2d(3, 16, 5)#卷积层1 (通道数3 卷积核16个 卷积尺寸5)
        self.pool1 = nn.MaxPool2d(2, 2)#下采样1
        self.conv2 = nn.Conv2d(16, 32, 5)#卷积层2
        self.pool2 = nn.MaxPool2d(2, 2)#下采样2
        self.fc1 = nn.Linear(32*5*5, 120)#全连接层1 节点个数120
        self.fc2 = nn.Linear(120, 84)#全连接层2  节点个数80
        self.fc3 = nn.Linear(84, 10)#全连接层3  (CIFAR10数据集,训练集类别为10,根据训练集进行修改)

    def forward(self, x):#正向传播的过程
        x = F.relu(self.conv1(x))    # input(3, 32, 32) output(16, 28, 28) BATCH
        x = self.pool1(x)            # output(16, 14, 14) 池化层只改变尺寸 不改变通道数
        x = F.relu(self.conv2(x))    # output(32, 10, 10) 卷积之后通过relu激活函数
        x = self.pool2(x)            # output(32, 5, 5)
        x = x.view(-1, 32*5*5)       # output(32*5*5) 展平成一个向量
        x = F.relu(self.fc1(x))      # output(120) 全连接1
        x = F.relu(self.fc2(x))      # output(84)  全连接2
        x = self.fc3(x)              # output(10)
        return x


#经过卷积计算之后的尺寸大小公式:(W-F+2P/S)+1 W是输入通道 F是卷积大小 P是填0 S是步长
#验证卷积神经网络CNN各个层的输出是否与理论计算正确
# 多行注释CTRL+/    CTRL+函数可以查看不同函数的相关参数以及功能

# import torch
# input1=torch.rand([32,3,32,32])
# model=LeNet()
# print(model)
# output1=model(input1)

2.train.py----加载数据集并进行训练,训练集计算loss,测试集计算accuracy,保存训练好的网络参数

代码如下(示例):

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

def main():
    transform = transforms.Compose(
        [transforms.ToTensor(),#对输入的图像数据做预处理,即由H W C (0-255) 转换为  C H W (0-1)
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])  #标准化

    # 50000张训练图片
    # 第一次使用时要将download设置为True才会自动去下载数据集  训练集下载到当前目录的data目录下 transform对图像预处理 表示是数据集中的训练集
    train_set = torchvision.datasets.CIFAR10(root='./data', train=True,#datasets.CIFAR10里面还有其他数据集
                                             download=False, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=36,#导入图片进行分批次batchsize=36 每次进行36张训练 是否打乱训练集
                                               shuffle=False, num_workers=0)

    # train_data_iter = iter(train_loader)  # 用来生成迭代器
    # train_image, train_label =  train_data_iter.next()  # 训练的照片以及标签

    # 10000张验证图片
    # 第一次使用时要将download设置为True才会自动去下载数据集
    val_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=False, transform=transform)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=32,
                                             shuffle=True, num_workers=0)
    val_data_iter = iter(val_loader)#用来生成迭代器
    val_image, val_label = val_data_iter.next()#验证的照片以及标签
    
    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # # functions to show an image
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))#c h w 转换为H W C
    #     plt.show()
    #
    #
    # # print labels
    # batch_size=4
    # print(' '.join(f'{classes[val_label[j]]:5s}' for j in range(batch_size)))
    # # show images
    # imshow(torchvision.utils.make_grid(val_image))


    net = LeNet()#定义训练的网络模型
    loss_function = nn.CrossEntropyLoss()#定义损失函数
    optimizer = optim.Adam(net.parameters(), lr=0.001)#优化器 (所有参数训练+学习率)

    for epoch in range(5):  #一个epoch即对整个训练集进行一次训练 训练集迭代多少轮次 5轮

        running_loss = 0.0
        for step, data in enumerate(train_loader, start=0):#遍历训练集数据
            inputs, labels = data #获取训练集的图像和标签
            optimizer.zero_grad()#历史梯度清零

            # 前向传播 + 反向传播 + 优化器
            outputs = net(inputs)#前向传播
            loss = loss_function(outputs, labels)#计算损失梯度
            loss.backward()#反向传播
            optimizer.step()#优化器更新参数

            #打印耗时、损失、准确率等数据
            running_loss += loss.item()
            if step % 500 == 499:    #每500步打印一次
                with torch.no_grad():#在以下步骤中(验证过程中)不用计算每个节点的损失梯度,防止内存占用
                    outputs = net(val_image)  # [batch, 10]验证集传入网络(test_batch_size=10),output维度为[10,10]
                    predict_y = torch.max(outputs, dim=1)[1]#预测分类 以output中值最大位置对应的索引(标签)作为预测输出
                    accuracy = torch.eq(predict_y, val_label).sum().item() / val_label.size(0)#计算精确度 转换为数值

                    print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %#打印epoch,step,loss,accuracy
                          (epoch + 1, step + 1, running_loss / 500, accuracy))
                    running_loss = 0.0

    print('Finished Training')

    save_path = './Lenet.pth'#保存权重,训练得到的参数
    torch.save(net.state_dict(), save_path)


if __name__ == '__main__':
    main()

3.predict.py——得到训练好的网络参数后,用自己找的图像进行分类测试

代码如下(示例):

import torch
import torchvision.transforms as transforms
from PIL import Image

from model import LeNet


def main():
    transform = transforms.Compose(
        [transforms.Resize((32, 32)),#缩放和网络大小的图片
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    classes = ('plane', 'car', 'bird', 'cat',
               'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    net = LeNet()
    net.load_state_dict(torch.load('Lenet.pth'))#载入权重文件

    #导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
    im = Image.open('bird.png')
    im = transform(im)  # [C, H, W]
    im = torch.unsqueeze(im, dim=0) #[N, C, H, W] 增加维度 tensor[batch, channel, height, width] 

    with torch.no_grad():#不计算梯度
        outputs = net(im)
        predict = torch.max(outputs, dim=1)[1].numpy()
    print(classes[int(predict)])


if __name__ == '__main__':
    main()

三、额外补充

在这里插入图片描述
1.pytorch tensor的通道排序:batch channel height width :分别是批次,通道数,高度和宽度
2.lenet网络输入是灰度图像,我们案例用的是彩色图像,所以通道数会是3x32x32
3.不懂的函数可以在pytorch官网里面查看

猜你喜欢

转载自blog.csdn.net/qq_45825952/article/details/123920829