我完全手写的Resnet50网络,终于把猫识别出来了

大家好啊,我是董董灿。

经常看我文章的同学,可能知道最近我在做一个小项目——《从零手写Resnet50实战》。

从零开始,用最简单的程序语言,不借用任何第三方库,完成Resnet50的所有算法实现和网络结构搭建,最终将下面这只猫识别出来。

图片

不幸的是,在刚搭建完网络之后,就试着运行了一下自己的神经网络,识别结果是错的

没办法,只能用 torch 搭了一个官方的网络,和我手写的神经网络,一层一层进行结果比对,然后调试(从零手写Resnet50实战——利用 torch 识别出了虎猫和萨摩耶)。

幸运的是,在经过一个数据一个数据对比之后,我的神经网络。

出猫了!

它竟然真的将猫识别出来了!


我的神经网络出猫现场

在从1000个分类得分中,将最大值的索引(第282号)挑出来之后,查询分类文件,便得到了分类结果:tiger cat

这个网络我没用 softmax。

因为 softmax 的作用是将结果“大的变得更显著,小的变得更微弱”,并不会改变结果的相对大小。

我直接从最后一个全连接层的输出,去找了最大值索引。

softmax的作用可以参考softmax原理

过程记录

识别出猫的过程,说难也不是太难,说简单但又有不少坑。

最难的在于出猫失败,查找原因的过程,是真的一层一层的进行结果对比。

好在我封装了一个对比函数,能帮助我在重要的网络节点,验证我的网络是否正确。

下面是出猫全流程记录。

首层 Conv2d + 第一个 BatchNorm2d + MaxPool 验证正确。

第一个 Layer 验证正确,共 10 个 Conv2d。

第二个 Layer 验证正确,共 13 个 Conv2d。

第三个 Layer 验证正确,共 19 个 Conv2d。

第四个 Layer 验证正确,共10个卷积。

AvgPool 和 FC 层也都验证正确。

整个验证过程还是挺痛苦的,但是看着一层层的打出来“succ”(success的缩写,说明和官方结果是一致的),还是很有成就感,并且挺治愈的。

下面简单说一下

我在出猫过程中遇到的那些坑

保存权值文件 layout 搞错

torch 默认的图片数据摆放格式是 NCHW,而我习惯写算法的方式是NHWC。

因此,在前期将图片导出时,没有考虑自己算法实现的习惯,而是将权值直接 flatten之后保存了。

结果就是,再将权值从文件中读入内存参与运算时,数据读取不正确。结果肯定是错的。

意识到这一点之后,因为我算法都已经写好,而且不想改了,于是,将保存权值的逻辑,在 flatten 之前,添加了一个 transpose 操作,将权值从NCHW 转为 NHWC,然后保存。

BatchNorm2d 的均值和方差使用错误

BatchNorm2d的算法实现有多种,特别需要注意的是需要区分该算法是在训练时用的还是推理时用的。

训练和推理时用的BatchNorm2d虽然公式是一样的,但实现方式却大不一样。

主要区别在于:

  • 训练时均值和方差需要根据本次的数据进行实时计算

  • 推理是使用的均值和方差是模型保存好的参数,在 torch 模型中,分别为 BatchNorm2d.running_mean 和 BatchNorm2d.running_var。

而我在最开始的算法实现时,均值和方差是自己手算的(对应的训练过程),而没有使用模型保存的均值和方差。

结果便是每层BatchNorm算出来的结果都差一点,这一点误差在层与层之间传递,导致到最后的识别结果中,误差被放大。

正是因为这个误差被逐层放大,就把猫识别成了一个水桶

残差结构问题

上一篇文章从零手写Resnet50实战——利用 torch 识别出了虎猫和萨摩耶分析残差结构可能会有问题。

实际验证残差结构没问题,就是一个简单地加法。

有问题的是上一层的BatchNorm2d,计算错误了。

那为什么计算错了,当时的分析仍然能和官方的计算结果对上呢?

是因为当时的官方计算忘了一个 model.eval() 调用。该调用会告诉模型运行在推理模式而不是训练模式。

而如果我不调用,显然用的训练模式,恰巧的 BatchNorm2d 的第一次实现,手算均值和方差,就对应了训练模式的算法。

于是结果刚好对上了,但这样最终识别的图片分类肯定还是错误的。

基本就遇到了这3个问题,在把这3个问题解决了之后,整个预测过程运行了大约40分钟,猫就被顺理成章的预测出来了。

于是,项目的第一阶段,就这么完成了。

下面会开启本项目的第二阶段——神经网络性能优化。

  • 用C++ 重新实现一遍所有算法:因为C++性能要比 python 手写的算法性能好很多

  • 重点优化卷积的性能:目前40多分钟有将近39分钟的实践花在了卷积上

  • 使用C++实现的版本争取在数秒内完成一张图片的推理

因为本次有2/3的坑都是BatchNorm算法引起的,后面会写一篇BatchNorm算法的文章,欢迎继续关注。

出猫,看起来也很简单。

欢迎持续关注本博主文章和本系列,一起从零开始,学算法,做实践项目。
这是一个可以写到简历上,亮瞎面试官双眼的项目哦

v v v v v v

点击下方卡片,关注我的公众号,有最新的文章和项目动态。

v v v v v v

猜你喜欢

转载自blog.csdn.net/dongtuoc/article/details/130333621