MNIST的手写数字识别

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/shuiliusheng/article/details/79763136

MNIST的手写数字识别

  1. 数据集

    • 传统的手写数字的数据集MNIST(http://yann.lecun.com/exdb/mnist/
    • 训练集为60000图片,图片像素为28*28。Images文件中存储图片,label文件中存储对应的数字
    • 测试集为10000图片,格式一致
      这里写图片描述
  2. 网络结构设计

    ​ 网络结构为三层神经网络,包括一个输入层,一个输出层,两个隐藏层。输入层为图片的向量形式,长度为784,即输入层包括784个神经元。两个隐藏层的个数分别为200和50。输出层的神经元个数为10个。预测结果从10个神经元的输出中选择最大值所在的位置作为结果。

  3. 激活函数的选择

    激活函数选择tanh。该函数的特点:

    • 输出在(-1,1)之间,并且在这个范围内tanh的变化最为剧烈,能够很好的将最小的差异变大,从而提高分类准确性

    • 对于三层网络,使用tanh能够保证最后一层的结果的范围仍然较大,而不会被缩减为很小的一个区间。(sigmoid在测试中出现了这样的问题)

    • 表达式

      t a n h ( z )   = e x p ( z ) e x p ( z ) e x p ( z )   + e x p ( x ) =   2 δ ( 2 z ) 1 ( δ s i g m o i d ) t a n h ( z ) = 1 ( e x p ( z ) e x p ( z ) e x p ( z )   + e x p ( x ) ) 2 = 1 t a n h 2 ( z )

    • 函数和导数图像
      这里写图片描述

  4. 损失函数的选择

    使用MSE作为最后的误差函数,同时增加了正则化项,减少网络中的参数W的值,具体的公式为

    J ( W , B ) = 1 2 p = 0 n ( A p O U T p ) 2 + λ 2 W i , j 2 J ( W , B ) A p = ( A p O U T p ) 2

  5. 网络参数的更新

    对于最后一层:f为激活函数,Z为神经元的输入,A为神经元输出

    δ p ( L ) = J ( W , B ) A p ( L ) f ( z p ( L ) ) J ( W , B ) w p q ( L ) = δ p ( L ) a p ( L 1 ) J ( W , B ) b ( L ) = δ p ( L )

    对于其他层:
    δ j ( l ) = k = 1 n l + 1 δ k ( l + 1 ) w k j ( l + 1 ) f ( z j ( l ) ) J ( W , B ) w j i ( l ) = δ j ( l ) a i ( l 1 ) J ( W , B ) b i ( l ) = δ j ( l )  

    W的更新:
    w j i = w j i - η * J ( W , B ) w j i - λ * w j i  

  6. 原始数据的预处理

    • 输入数据的处理

      首先由于激活函数在(-2,2)之间变化较为明显,而输入图片的每一个像素点为(0,256),因此将输入数据进行归一化处理,转换为(0,1)之间的数值进行输入。

    • 输出数据的处理

      同样是由于激活函数的影响,需要对实际的label进行一定的变换。激活函数的输出为(-1,1),因此设计输出为10维向量,在向量中数字最大的位置即为输出。通过这样的设计能够将相近的数据划分的更开,提高分类的准确率。

  7. 网络参数初始化

    方法:随机初始化为(-0.33,0.33)

    W[i][j] = (rand () %2000-1000)/3000.0

    解释:

    • 输入数据为(0,1)之间,但是向量长度为784,因此如果参数初始化不够小,则通过计算累加和,神经元的输入将变为很大。通过激活函数计算后结果仍旧将都是-1或者是1。因此神经元的第一层输出很可能都变成类似的数字,这并不是想要的结果。(sigmoid会有这样的问题出现)
    • 初始化有负数是因为激活函数的变化较大的区间在(-2,2)之间,因此需要使得神经元的输入尽可能负数和正数都有
  8. 测试结果

    η ( 0.15 , 0.2 , 0.25 , 0.3 , 0.35 , 0.4 ) λ ( 0.01 , 0.005 , 0.001 , 0.0001 ) b a t c h s i z e ( 6 , 64 , 128 , 256 )

    这里写图片描述

github:https://github.com/Shuiliusheng/2018/tree/master/C-_for_NN/three_layers_nn_mnist

猜你喜欢

转载自blog.csdn.net/shuiliusheng/article/details/79763136