手写体数字识别

转自:

本节代码地址:


https://github.com/vic-w/torch-practice/tree/master/mnist

MNIST是手写数字识别的数据库。在深度学习流行的今天,MNIST数据库已经被大家玩坏了。但是用它来学习卷积神经网络是再好不过的了。这一次,我们就用Torch来实现MNIST数据库的识别。

这一次的代码用到了mnist库,如果之前没有安装,可以在命令行键入:

[plain]  view plain  copy
  1. luarocks install mnist  


和往常一样,我们要先包含必要的库

[plain]  view plain  copy
  1. require 'torch'  
  2. require 'nn'  
  3. require 'optim'  
  4. mnist = require 'mnist'  
其中require ’mnist'一句返回了一个mnist的对象。可以用下面两句来获得mnist的图像数据。

[plain]  view plain  copy
  1. fullset = mnist.traindataset()  
  2. testset = mnist.testdataset()  


接下来我们主要关注的就是模型如何建立,因为除了模型之外,其他的代码都是大同小异的。

首先来看一下著名的LeNet网络模型:



这个图是从Caffe那里借来的。其实用Caffe训练LeNet模型真的是又快又好。但我们的目的是为了学习,所以我们还是用Torch再把它实现一遍。

从上面的图上可以看出,这个模型首先使用了Scale层(蓝色),把输入的图片取值缩小到一定范围内。然后是连续两个卷积层conv1和conv2(红色),分别接一个池化层pool1和pool2(黄色),最后是两个全连接层ip1和ip2 (紫色,ip=inner product),将所有的信息归结到最后的Softmax层(蓝色)。其中只有在两个全连接层之间使用了激活层,激活方式是ReLU(Rectified Linear Units)

将这个模型用Torch的代码实现也非常简单。首先仍然是建立一个容器用来存放各种模块。

[plain]  view plain  copy
  1. model = nn.Sequential()  


放入一个reshape模块。因为mnist库的原始图片是储存为1列728个像素的。我们需要把它们变成1通道28*28的一个方形图片。

[plain]  view plain  copy
  1. model:add(nn.Reshape(1, 28, 28))  


接下来要把图片的每个像素除以256再乘以3.2,也就是把像素的取值归一化到0至3.2之间。这相当于Caffe中的Scale模块。

[plain]  view plain  copy
  1. model:add(nn.MulConstant(1/256.0*3.2))  


然后是第一个卷积层,它的参数按顺序分别代表:输入图像是1通道,卷积核数量20,卷积核大小5*5,卷积步长1*1,图像留边0*0

[plain]  view plain  copy
  1. model:add(nn.SpatialConvolutionMM(1, 20, 5, 5, 1, 1, 0, 0))  


一个池化层,它的参数按顺序分别代表:池化大小2*2,步长2*2,图像留边0*0

[plain]  view plain  copy
  1. model:add(nn.SpatialMaxPooling(2, 2 , 2, 2, 0, 0))  


再接一个卷积层和一个池化层,由于上一个卷积层的核的数量是20,所以这时输入图像的通道个数为20

[plain]  view plain  copy
  1. model:add(nn.SpatialConvolutionMM(20, 50, 5, 5, 1, 1, 0, 0))  
  2. model:add(nn.SpatialMaxPooling(2, 2 , 2, 2, 0, 0))  


在接入全连接层之前,我们需要把数据重新排成一列,所以有需要一个reshape模块

[plain]  view plain  copy
  1. model:add(nn.Reshape(4*4*50))  


这个参数为什么是4*4*50即800呢?其实是这样算出来的:我们的输入是1通道28*28的图像,经过第一个卷积层之后变成了20通道24*24的图像。又经过池化层,图像尺寸缩小一半,变为20通道12*12。通过第二个卷积层,变为50通道8*8的图像,又经过池化层缩小一半,变为50通道4*4的图像。所以这其中的像素一共有4*4*50=800个。

接下来是第一个全连接层。输入为4*4*50=800,输出为500

[plain]  view plain  copy
  1. model:add(nn.Linear(4*4*50, 500))  


两个全连接层之间有一个ReLU激活层

[plain]  view plain  copy
  1. model:add(nn.ReLU())  


然后是第二个全连接层,输入是500,输出是10,也就代表了10个数字的输出结果,哪个节点的响应高,结果就定为对应的数字。

[plain]  view plain  copy
  1. model:add(nn.Linear(500, 10))  


最后是一个LogSoftMax层,用来把上一层的响应归一化到0至1之间。

[plain]  view plain  copy
  1. model:add(nn.LogSoftMax())  


模型的建立就完成了。我们还需要一个判定标准。由于我们这一次是要解决分类问题,一般使用nn.ClassNLLCriterion这种类型的标准(Negative Log Likelihood)

[plain]  view plain  copy
  1. criterion = nn.ClassNLLCriterion()  


为了要达到更好的优化效果,这里需要对model内部参数的初始化做一下特殊的处理。还记得torch会帮我们随机初始化参数吗?我们现在不使用torch的初始化参数,而使用一种更高级的初始化方法,称之为xavier方法。概括来讲,就是根据每层的输入个数和输出个数来决定参数随机初始化的分布范围。在代码里只需要一句:

[plain]  view plain  copy
  1. model = require('weight-init')(model, 'xavier')  


其中的‘weight-init’指向了与主文件同一文件夹里的weight-init.lua这个文件。xavier方法就在这个文件里面。它是由 https://github.com/e-lab/torch-toolbox 所实现的。

到这里,网络模型的部分就都已经完成了。我们现在就需要建立评估函数,然后循环迭代就可以了。这些都是例行公事,可以参照前面的代码来写,这里就不在赘述了。

完整的代码在这里,大家可以运行试一试。

[plain]  view plain  copy
  1. require 'torch'  
  2. require 'nn'  
  3. require 'optim'  
  4. --require 'cunn'  
  5. --require 'cutorch'  
  6. mnist = require 'mnist'  
  7.   
  8. fullset = mnist.traindataset()  
  9. testset = mnist.testdataset()  
  10.   
  11. trainset = {  
  12.     size = 50000,  
  13.     data = fullset.data[{{1,50000}}]:double(),  
  14.     label = fullset.label[{{1,50000}}]  
  15. }  
  16.   
  17. validationset = {  
  18.     size = 10000,  
  19.     data = fullset.data[{{50001,60000}}]:double(),  
  20.     label = fullset.label[{{50001,60000}}]  
  21. }  
  22.   
  23. trainset.data = trainset.data - trainset.data:mean()  
  24. validationset.data = validationset.data - validationset.data:mean()  
  25.   
  26. model = nn.Sequential()  
  27. model:add(nn.Reshape(1, 28, 28))  
  28. model:add(nn.MulConstant(1/256.0*3.2))  
  29. model:add(nn.SpatialConvolutionMM(1, 20, 5, 5, 1, 1, 0, 0))  
  30. model:add(nn.SpatialMaxPooling(2, 2 , 2, 2, 0, 0))  
  31. model:add(nn.SpatialConvolutionMM(20, 50, 5, 5, 1, 1, 0, 0))  
  32. model:add(nn.SpatialMaxPooling(2, 2 , 2, 2, 0, 0))  
  33. model:add(nn.Reshape(4*4*50))  
  34. model:add(nn.Linear(4*4*50, 500))  
  35. model:add(nn.ReLU())  
  36. model:add(nn.Linear(500, 10))  
  37. model:add(nn.LogSoftMax())  
  38.   
  39. model = require('weight-init')(model, 'xavier')  
  40.   
  41. criterion = nn.ClassNLLCriterion()  
  42.   
  43. --model = model:cuda()  
  44. --criterion = criterion:cuda()  
  45. --trainset.data = trainset.data:cuda()  
  46. --trainset.label = trainset.label:cuda()  
  47. --validationset.data = validationset.data:cuda()  
  48. --validationset.label = validationset.label:cuda()  
  49.   
  50. sgd_params = {  
  51.    learningRate = 1e-2,  
  52.    learningRateDecay = 1e-4,  
  53.    weightDecay = 1e-3,  
  54.    momentum = 1e-4  
  55. }  
  56.   
  57. x, dl_dx = model:getParameters()  
  58.   
  59. step = function(batch_size)  
  60.     local current_loss = 0  
  61.     local count = 0  
  62.     local shuffle = torch.randperm(trainset.size)  
  63.     batch_size = batch_size or 200  
  64.     for t = 1,trainset.size,batch_size do  
  65.         -- setup inputs and targets for this mini-batch  
  66.         local size = math.min(t + batch_size - 1, trainset.size) - t  
  67.         local inputs = torch.Tensor(size, 28, 28)--:cuda()  
  68.         local targets = torch.Tensor(size)--:cuda()  
  69.         for i = 1,size do  
  70.             local input = trainset.data[shuffle[i+t]]  
  71.             local target = trainset.label[shuffle[i+t]]  
  72.             -- if target == 0 then target = 10 end  
  73.             inputs[i] = input  
  74.             targets[i] = target  
  75.         end  
  76.         targets:add(1)  
  77.         local feval = function(x_new)  
  78.             -- reset data  
  79.             if x ~= x_new then x:copy(x_new) end  
  80.             dl_dx:zero()  
  81.   
  82.             -- perform mini-batch gradient descent  
  83.             local loss = criterion:forward(model:forward(inputs), targets)  
  84.             model:backward(inputs, criterion:backward(model.output, targets))  
  85.   
  86.             return loss, dl_dx  
  87.         end  
  88.   
  89.         _, fs = optim.sgd(feval, x, sgd_params)  
  90.   
  91.         -- fs is a table containing value of the loss function  
  92.         -- (just 1 value for the SGD optimization)  
  93.         count = count + 1  
  94.         current_loss = current_loss + fs[1]  
  95.     end  
  96.   
  97.     -- normalize loss  
  98.     return current_loss / count  
  99. end  
  100.   
  101. eval = function(dataset, batch_size)  
  102.     local count = 0  
  103.     batch_size = batch_size or 200  
  104.       
  105.     for i = 1,dataset.size,batch_size do  
  106.         local size = math.min(i + batch_size - 1, dataset.size) - i  
  107.         local inputs = dataset.data[{{i,i+size-1}}]--:cuda()  
  108.         local targets = dataset.label[{{i,i+size-1}}]:long()--:cuda()  
  109.         local outputs = model:forward(inputs)  
  110.         local _, indices = torch.max(outputs, 2)  
  111.         indices:add(-1)  
  112.         local guessed_right = indices:eq(targets):sum()  
  113.         count = count + guessed_right  
  114.     end  
  115.   
  116.     return count / dataset.size  
  117. end  
  118.   
  119. max_iters = 30  
  120.   
  121. do  
  122.     local last_accuracy = 0  
  123.     local decreasing = 0  
  124.     local threshold = 1 -- how many deacreasing epochs we allow  
  125.     for i = 1,max_iters do  
  126.         local loss = step()  
  127.         print(string.format('Epoch: %d Current loss: %4f', i, loss))  
  128.         local accuracy = eval(validationset)  
  129.         print(string.format('Accuracy on the validation set: %4f', accuracy))  
  130.         if accuracy < last_accuracy then  
  131.             if decreasing > threshold then break end  
  132.             decreasing = decreasing + 1  
  133.         else  
  134.             decreasing = 0  
  135.         end  
  136.         last_accuracy = accuracy  
  137.     end  
  138. end  
  139.   
  140. testset.data = testset.data:double()  
  141. eval(testset)  

猜你喜欢

转载自blog.csdn.net/ccccccod/article/details/79055565