我的AI之路(14)--Caffe example:使用MNIST数据集训练和测试LeNet-5模型

       好久以前就体验了,没时间作记录,现在补上。

       MNIST(Mixed National Institute of Standards and Technology)是一个手写数字图形数据库,由Yann LeCun(在深度学习的各种书中经常会被提到的大神之一,CNN的提出者) N年前整理出来的,训练数据集有60000个样本,测试数据集有10000个样本,每张图都已进行了尺寸归一化,固定尺寸为28像素x28像素,常被用作体验训练和测试网络模型的示例数据,好比你刚学C/C++/Java时仿照书上说的尝试写个Hellow Word程序然后编译执行那样偷笑

     1)执行get_mnist.sh下载MNIST数据集的四个文件(以下假设当前目录是在caffe顶级目录下):

         cd data/mnist/

         ./get_mnist.sh

     完成后可以看到train-images-idx3-ubyte、train-labels-idx1-ubyte、t10k-images-idx3-ubyte、t10k-labels-idx1-ubyte这四个文件被下载到当前位置,他们分别是训练数据集(图片,60000条目)、训练数据集(标签,60000条目)、测试数据集(图片,10000条目)、测试数据集(标签,10000条目):

      2)回到caffe顶级目录下,执行

         ./examples/mnist/create_mnist.sh

      在examples/mnist/下生成了mnist_train_lmdb和mnist_test_lmdb两个目录,每个目录下都有data.mdb和lock.mdb两个文件,也就是把训练和测试数据的图片和标签文件里的数据导入到了这两个训练和测试数据库里了,以供后面caffe执行训练和测试脚本时从对应数据库里读取数据。

      3)在caffe顶级目录下执行

         ./examples/mnist/tain_lenet.sh

 train_lenet.sh的内容是:

#!/usr/bin/env sh
set -e
./build/tools/caffe train --solver=examples/mnist/lenet_solver.prototxt $@

其中examples/mnist/lenet_solver.prototxt是定义全网络用到的所谓超参数(Hyper-Parameter)的文件,也就是指定网络模型定义文件、迭代次数、学习率、冲量、学习率衰减gamma值、迭代多少次保存一次训练快照、使用CPU还是GPU计算等参数。

        在我本子上,使用GPU计算配置的话执行2分钟左右即完成了,使用CPU计算配置的话需要5分多钟,可见GPU计算还确实时快了不少,执行完train后的结果日志大致是这样:

可见最终的训练结果精度达到了0.9908,很高的准确率。

训练过程中每迭代5000次保存一次网络各个权值的参数到.caffemodel文件里,同时保存一次训练状态到.solverstate文件里,相当于每迭代5000次就保存一次快照(snapshot):

     4)执行下面的命令,用网络模型定义文件examples/mnist/lenet_train_test.prototxt和刚才迭代10000次训练出的网络模型的参数的存储文件在测试数据集上作测试:

       ./build/tools/caffe test -model examples/mnist/lenet_train_test.prototxt -weights examples/mnist/lenet_iter_10000.caffemodel -iterations 100

测试只需花十几秒就完成了,结果精度也是0.9908:

对于caffe的各种参数的含义,可以在caffe顶级目录下执行./build/tools/caffe即可打印出各个参数的含义:

我的AI之路(1)--前言

我的AI之路(2)--安装Fedora 28

我的AI之路(3)--安装Anaconda3 和Caffe

我的AI之路(4)--在Anaconda3 下安装Tensorflow 1.8

我的AI之路(5)--如何选择和正确安装跟Tensorflow版本对应的CUDA和cuDNN版本

我的AI之路(6)--在Anaconda3 下安装PyTorch

我的AI之路(7)--安装OpenCV3_Python 3.4.1 + Contrib以及PyCharm

我的AI之路(8)--体验用OpenCV 3的ANN进行手写数字识别及解决遇到的问题

我的AI之路(9)--使用scikit-learn

我的AI之路(10)--如何在Linux下安装CUDA和CUDNN

我的AI之路(11)--如何解决在Linux下编译OpenCV3时出现的多个错误

我的AI之路(12)--如何配置Caffe使用GPU计算并解决编译中出现的若干错误

我的AI之路(13)--解决编译gcc/g++源码过程中出现的错误

我的AI之路(14)--Caffe example:使用MNIST数据集训练和测试LeNet-5模型

我的AI之路(15)--Linux下编译OpenCV3的最新版OpenCV3.4.1及错误解决

我的AI之路(16)--云服务器上安装和调试基于Tensorflow 1.10.1的训练环境

我的AI之路(17)--Tensorflow和Caffe的API及Guide

我的AI之路(18)--Tensorflow的模型安装之object_detection

猜你喜欢

转载自blog.csdn.net/XCCCCZ/article/details/80959386
今日推荐