基于Python实现的CNN卷积神经网络训练与识别

资源下载地址:https://download.csdn.net/download/sheziqiong/85626941

实验内容和要求

  • 编写程序,实现 LeNet-5 卷积神经网络,对 MNIST 手写数字数据库进行训练与识别,展示准确率等。
  • 自己选择神经网络,对 CIFAR-10 数据库进行图像物体训练与识别。

实验器材

Python 3.7

开发平台:Windows10 Visual Studio Code

机器学习库:torch 1.6.0 torchvision 0.7.0

辅助:CUDA 10.2,用于进行 GPU 加速

具体实现

3.1 LeNet-5 实现

使用 torch 的 nn.Module 类的派生,可以编写 LeNet5 的结构如下:其中调用 nn.Conv2d()函数进行卷积层设置,用 nn.Linear()函数进行全连接操作。在正向传导的过程中,规定了两次池化,使用 F.max_pool2d 函数。每一经过一层,对结果调用 F.relu()函数进行激活,形成新的输出。

在实现卷积神经网络的过程中,调用 pytorch 的数据加载模块的部分是遇到的一个难点。调用 torch.utils.data.DataLoader(),设定批的大小,是否随机重组,以及 num_workers(进程数),由于使用的是 Windows 所以对多线程支持的并不好。

训练过程:使用优化函数 optimizer(选用 Adam 算法)和损失函数(交叉熵函数 CrossEntropyLoss),对 loss 调用 backard()函数进行反向传播过程。注意在训练前对网络进行 train()设置,启用 batchnormalization 和 dropout,防止网络过拟合。

测试过程:启用 eval()模式,对输入数据进入网络进行传播,对输出的 output 取极大值作为预测结果 pred。

3.2 AlexNet 实现

网络定义如下图:

扫描二维码关注公众号,回复: 14267493 查看本文章

注意在训练前也要对数据先做预处理,利用 torchvision 的处理函数进行 resize 和转换成张量(tensor)的处理。另外调用 Normalize 函数,将原来的 tensor 从(0,1) 变换到(-1,1)区间。

对 CIFAR-10 的训练和检测与 MNIST 的思想是类似的,不再赘述。

实验结果与分析

4.1 LeNet-5 对 MNIST 的训练与识别

设置 BATCH_SIZE 为 512,总共训练 10 个 epoch。每次一个 epoch 在过完一遍训练数据之后再过一遍测试数据,得到一次准确度和损失函数的值。训练和测试的输出结果保存在 LeNet.log 里,模型保存为 LeNet.pth。

对训练结果进行可视化处理如下:

4.2 AlexNet 对 CIFAR-10 的训练与识别

设置 BATCH_SIZE 为 32,总共训练 20 个 epoch。每次一个 epoch 在过完一遍训练数据之后再过一遍测试数据,得到一次准确度和损失函数的值。训练和测试的输出结果保存 AlexNet.log 里,模型保存为 AlexNet.pth。

因为 AlexNet 网络比较复杂,而且 CIFAR-10 数据量也较大,现将训练的网络结构打印如下验证是否正确:

我们对训练后的结果先随机选择一批数据进行测试:

对比实际标签和预测标签:在 32 张图中正确判断了 27 张,正确率约为 84%。

GroundTruth:  cat  ship  ship airplane  frog  frog  automobile  frog   cat   automobile  airplane truck   dog horse truck  ship   dog horse  ship  frog horse  airplane  deer  truck
dog   bird  deer airplane truck  frog  frog   dog 
Predicted:    cat  ship  ship airplane  frog  frog  truck     frog   cat   automobile airplane  truck   dog horse truck  ship   dog horse  ship  frog horse  bird      airplane truck  deer  frog  deer airplane truck  frog  frog   dog 

另外,在五万张训练数据中测试的结果显示正确率为 92%,在一万张新的测试数据中的结果为 77%。在十个标签中,正确率最高的 ship 达到了 91%,最低的 cat 也有近六成的判断正确率。

资源下载地址:https://download.csdn.net/download/sheziqiong/85626941

猜你喜欢

转载自blog.csdn.net/newlw/article/details/125256884