基于Python ResNet18 网络的十分类任务【100011286】

基于 ResNet18 网络的十分类任务

1 任务描述

1.1 作业要求

完成 10 类图片的分类问题,图像示例及类别如下:

图 1:图像示例

1.2 作业完成情况

本次大作业我使用 python 语言,利用 pytorch 框架搭建 Resnet18 网络进行分类任务,最终实现了 90% 左右的准确率。

2 问题建模

本次大作业我最终采用了 ResNet18 作为最终的网络架构,虽然 torchvision.models 有对应的函数,但是为了我还是选择自己手动编写这一网络。

2.1ResidualBlock

结构图如图 2(a)所示。优点:使用普通的连接,上层的梯度必须要一层一层传回来,而使用残差连接,相当于中间有了一条更短的路,梯度能够从这条更短的路传回来,避免了梯度过小的情况。拟合公式:H(X)=F(X)+X

(a)ResidualBlock (b)ResNet18

图 2:神经网络结构图

2.2ResNet18

通过若干 ResidualBlock 的堆叠,可以构造出如图 2(b)所示的残差网络 ResNet18.

其各 Layer 参数如下表

表 1:LayersofResNet18

Layer(type) OutputShape Param#
Conv2d-1 [-1,64,32,32] 576
BatchNorm2d-2 [-1,64,32,32] 128
ReLU-3 [-1,64,32,32] 0
Conv2d-4 [-1,64,32,32] 36,864
BatchNorm2d-5 [-1,64,32,32] 128
ReLU-6 [-1,64,32,32] 0
Conv2d-7 [-1,64,32,32] 36,864
BatchNorm2d-8 [-1,64,32,32] 128
ResidualBlock-9 [-1,64,32,32] 0
Conv2d-10 [-1,64,32,32] 36,864
续表
Layer(type) OutputShape Param#
BatchNorm2d-ll [-1,64,32,32] 128
ReLU-12 [-1,64,32,32] 0
Conv2d-13 [-1,64,32,32] 36,864
BatchNorm2d-14 [-1,64,32,32] 128
ResidualBlock-15 [-1,64,32,32] 0
Conv2d-16 [-1,128,16,16] 73,728
BatchNorm2d-17 [-1,128,16,16] 256
ReLU-18 [-1,128,16,16] 0
Conv2d-19 [-1,128,16,16] 147,456
BatchNorm2d-20 [-1,128,16,16] 256
Conv2d-21 [-1,128,16,16] 8,192
BatchNorm2d-22 [-1,128,16,16] 256
ResidualBlock-23 [-1,128,16,16] 0
Conv2d-24 [-1,128,16,16] 147,456
BatchNorm2d-25 [-1,128,16,16] 256
ReLU-26 [-1,128,16,16] 0
Conv2d-27 [-1,128,16,16] 147,456
BatchNorm2d-28 [-1,128,16,16] 256
ResidualBlock-29 [-1,128,16,16] 0
Conv2d-30 [-1,256,8,8] 294,912
BatchNorm2d-31 [-1,256,8,8] 512
ReLU-32 [-1,256,8,8] 0
Conv2d-33 [-1,256,8,8] 589,824
BatchNorm2d-34 [-1,256,8,8] 512
Conv2d-35 [-1,256,8,8] 32,768
BatchNorm2d-36 [-1,256,8,8] 512
ResidualBlock-37 [-1,256,8,8] 0
Conv2d-38 [-1,256,8,8] 589,824
BatchNorm2d-39 [-1,256,8,8] 512
ReLU-40 [-1,256,8,8] 0
Conv2d-41 [-1,256,8,8] 589,824
BatchNorm2d-42 [-1,256,8,8] 512
ResidualBlock-43 [-1,256,8,8] 0
续表
Layer(type) OutputShape Param#
Conv2d-44 [-1,512,4,4] 1,179,648
BatchNorm2d-45 [-1,512,4,4] 1,024
ReLU-46 [-1,512,4,4] 0
Conv2d-47 [-1,512,4,4] 2,359,296
BatchNorm2d-48 [-1,512,4,4] 1,024
Conv2d-49 [-1,512,4,4] 131,072
BatchNorm2d-50 [-1,512,4,4] 1024
ResidualBlock-51 [-1,512,4,4] 0
Conv2d-52 [-1,512,4,4] 2,359,296
BatchNorm2d-53 [-1,512,4,4] 1,024
ReLU-54 [-1,512,4,4] 0
Conv2d-55 [-1,512,4,4] 2,359,296
BatchNorm2d-56 [-1,512,4,4] 1,024
ResidualBlock-57 [-1,512,4,4] 0
Linear-58 [-1,10] 5,130

ResNet18 参数:

以上数据根据 torchsummary 中自带的 summary 函数生成,可知该网络共有 58 层,总共 11,172,810 个参数。

3 算法设计和实现

3.1 优化器的选择

本次大作业选用的是 Adam 优化器,Adam 本质上是带有动量项的 RMSprop,它利用梯度的一阶矩估计和二阶矩估计动态调整每个参数的学习率。它的优点主要在于经过偏置校正后,每一次迭代学习率都有个确定范围,使得参数比较平稳。其公式如下:

其中,前两个公式分别是对梯度的一阶矩估计和二阶矩估计,可以看作是对期望 E|gt|,E|g2| 的估计;公式 3、4 是对一阶二阶矩估计的校正,这样可以近似为对期望的无偏估计。可以看出,直接对梯度的矩估计对内存没有额外的要求,而且可以根据梯度进行动态调整。最后一项前面部分是对学习率 n 形成的一个动态约束,而且有明确的范围。

优化器实现:

可使用 torch.optim.Adam 函数直接实现,其参数:params(iterable):可用于迭代优化的参数或者定义参数组的 dicts;lr(float,optional):学习率;

betas(Tuple[float,float],optional):用于计算梯度的平均和平方的系数;eps(float,optional):为了提高数值稳定性而添加到分母的一个项;weight_decay(float,optional):权重衰减(如 L2 惩罚);

3.2 损失函数的选择

选择交叉熵作为损失函数,其公式为:

其中 p(xi)是预测结果,q(xi)是 groundtruth。损失函数实现:可使用 nn.CrossEntropyLoss()函数直接实现。

3.3 数据增强

定义如下的数据增强对训练集数据进行处理:

处理效果如下:

图 3:数据增强效果

3.4 学习率衰减

图 4:学习率衰减曲线

采用学习率递减的算法,初始学习率设为 0.1,其变化为每 10 个 epoch 衰减二分之一,则前 90 个 Epoch 的学习率如图 4 所示.

4 分类结果

在 30000 个数据中取 27000 个作为数据集,3000 个作为验证集,利用上述网络和方法进行分类,结果如图 5。

4.1 曲线图

(a)准确率曲线

(b)损失函数曲线

图 5:分类结果曲线

4.2 准确率曲线分析

图 5(a)是前 90 个 Epoch 的准确率曲线,其中红线代表训练集准确率,蓝线代表验证集准确率,通过观察可以发现,训练集准确率一直呈现上升趋势,而验证集准确率在 40 个 Epoch 之后基本达到稳定,稳定值在 0.89 左右。同时在学习率减半的地方(Epoch10 和 20)处,正确率会发生阶跃的跳变。

4.3 损失函数曲线分析

图 5(b)是前 90 个 Epoch 的损失函数曲线,其中红线代表训练集损失函数,蓝线代表验证集损失函数,通过观察可以发现,训练集损失函数一直呈现下降趋势,而验证集损失函数在 40 个 Epoch 之后基本逐渐升高,可见此时模型出现了过拟合现象。因此佳的训练组数为 40 个 Epoch 左右。同时在学习率减半的地方(Epoch10 和 20)处,损失函数会发生阶跃的跳变。

4.4 主成分分析

图 6:主成分分析结果图分析:用验证集数据对网络的输出层进行主成分分析,可见不同类型的数据分类情况较好,有重叠的原因主要是由于数据降维导致的。

5 实验总结

这是我第一次尝试使用 Pytorch 框架进行卷积神经网络分类任务的搭建,我在写大作业的过程中尝试了许多主流的神经网络,如 AlexNet、ResNet、DenseNet、VGG 等,这一过程虽然耗费了我大量的时间,但是从中我也学习到了许多经典网络的架构及其实现。

在初期的时候,我还不懂得划分数据集与验证集,所以盲目地提交了好多次数据,也给自己调参造成了一定的不便。同时调参我一开始也没有掌握方法,随后我逐渐学会了根据训练集和验证集的准确率和损失曲线来调节,这使我的准确率有了较大的提升。

最终是 Pytorch 框架的学习,掌握了 CNN 的原理并不能让我写出代码,学习这一部分的框架也格外重要,我在 CSDN 上找到了一篇 Pytorch 识别 MNIST 手写数字数据集的样例代码,逐一查找样例代码中各函数的用法,终于学会了这一框架的用法,深感其方便。

总之,虽然这次大作业花费了我很多时间,但我确实从中收获了许多。

♻️ 资源

在这里插入图片描述

大小: 13.0MB
➡️ 资源下载:https://download.csdn.net/download/s1t16/87575180
注:如当前文章或代码侵犯了您的权益,请私信作者删除!

猜你喜欢

转载自blog.csdn.net/s1t16/article/details/131675199