C++调用caffe中训练好的LeNet网络识别数字

1、LeNet模型介绍

LeNet是一种CNN模型,整体框架如下图

可以看出LeNet网络结构规模比较小,但包含了卷积层,pooling层,全连接层等CNN网络的基本组件

第一层:输入层是32*32大小的图像,希望字母的一些特征能够出现在最高层特征检测子感受野的中心

第二层:C1层是一个卷积层,6个feature map,5*5大小的卷积核,每个featuremap有(32-5+1)*(32-5+1)即28*28个神经元,每个神经元都与输入层5*5大小的区域相连,于是C1层共有(5*5+1)*6=156个训练参数。两层之间的连接数为156*(28*28)个。通过卷积运算,使原信号特征增强,并且降低噪音,而且不同的卷积核能够提取到图像中的不同特征。

第三层:S2层是一个下采样层,有6个14*14的featuremap,每个featuremap中的每个神经元都与C1层对应的feature中的2*2的区域相连。下采样层的目的是为了降低网络训练参数及模型的过拟合程度。下采样层的池化方式有两种:1、选择Pooling窗口中的最大值作为采样值的叫最大值池化。2、将Pooling窗口中的所有值相加取平均值作为采样值,即均值池化。

第四层:C3层也是一个卷积层,有16个featuremap,每个featuremap由上一层的各featuremap之间的不同组合

第五层:S4层是一个下采样层,由16个5*5大小的featuremap构成,每个神经元与C3中对应featuremap的2*2大小的区域相连

第六层:C5层又是一个卷积层,同样使用5*5的卷积核。

第七层:F6全连接层有84个featuremap,每个featuremap只有一个神经元与C5层全相连,F6层计算输入向量和权重向量之间的点积和偏置,之后将其传递给sigmoid函数来计算神经元

第八层:输出层也是全连接层,共有10个节点,分别代表数字0到9,如果节点i的值为0,则网络识别的结果是数字i

2、利用Caffe得到训练模型

#caffe目录下
#1、下载数据集
sh ./data/mnist/getmnist.sh

#2、转换格式
sh ./examples/mnist/create_mnist.sh
#这一步是将下载的数据集的二进制格式转换成caffe能识别的lmdb格式
#执行以后,example/mnist目录下出现mnist_train_lmdb和mnist_test_lmdb

#首先确定lenet_train_test.prototxt文件在红的source参数文件路径没有问题(lmdb文件的路径)
#3、训练数据集
sh ./example/mnist/train_lenet.sh
#训练结束后,会出现一个accuracy=0.991的,代表分类准确率为99.1%

此时目录example/mnist下会出现两个重要文件lenet.prototxt、lenet_iter_10000.caffemodel

3、建立C++工程

main.cpp

#include <iostream>
#include <opencv2/opencv.hpp>
#include <caffe/caffe.hpp>
#include <string>
using namespace caffe;
using namespace std;
int main(int argc,char* argv[]) {
    typedef float type;
    type ary[28*28];
 
    //在28*28的图片颜色为RGB(255,255,255)背景上写RGB(0,0,0)数字.
    cv::Mat gray(28,28,CV_8UC1,cv::Scalar(255));
    cv::putText(gray,argv[3],cv::Point(4,22),5,1.4,cv::Scalar(0),2);
 
    //将图像的数值从uchar[0,255]转换成float[0.0f,1.0f],的数, 且颜色取相反的 .
    for(int i=0;i<28*28;i++){
            // f_val =(255-uchar_val)/255.0f
            ary[i] = static_cast<type>(gray.data[i]^0xFF)*0.00390625;   
    }
 
    cv::imshow("x",gray);
    cv::waitKey();
 
    //set cpu running software
    Caffe::set_mode(Caffe::CPU);
 
    //load net file , caffe::TEST 用于测试时使用
    Net<type> lenet(argv[1],caffe::TEST);
 
    //load net train file caffemodel
    lenet.CopyTrainedLayersFrom(argv[2]);
 
 
 
    Blob<type> *input_ptr = lenet.input_blobs()[0];
    input_ptr->Reshape(1,1,28,28);
 
    Blob<type> *output_ptr= lenet.output_blobs()[0];
    output_ptr->Reshape(1,10,1,1);
 
    //copy data from <ary> to <input_ptr>
    input_ptr->set_cpu_data(ary);
 
    //begin once predict
    lenet.Forward();
 
 
    const type* begin = output_ptr->cpu_data();
 
    // get the maximum index
    int index=0;
    for(int i=1;i<10;i++){
        if(begin[index]<begin[i])
        index=i;
    }
 
    // 打印这次预测[0,9]的每一个置信度
    for(int i=0;i<10;i++)
        cout<<i<<"\t"<<begin[i]<<endl;
 
    // 展示最后的预测结果
    cout<<"res:\t"<<index<<"\t"<<begin[index]<<endl;
    return 0;
}

CMakeLists.txt

cmake_minimum_required (VERSION 2.8)
 
project (classification_test)
 
add_executable(classification_test main.cpp)
#SET (SRC_LIST main.cpp)
include_directories ( /home/dzh/caffe/include
    /usr/local/include
    /usr/local/cuda/include
    /usr/include )

#使用opencv
find_package(OpenCV REQUIRED)
target_link_libraries(classification_test ${OpenCV_LIBS} )
 
target_link_libraries(classification_test
    /home/dzh/caffe/build/lib/libcaffe.so
    /usr/lib/x86_64-linux-gnu/libglog.so
    /usr/lib/x86_64-linux-gnu/libboost_system.so
    )

#add_library表示编译时在lib文件夹下会生成libclassification.so文件
#如果需要生成.so文件,则需要将上面SET处解除注释
#add_library(so_test SHARED ${SRC_LIST})

然后可以将生成的两个文件lenet.prototxt、lenet_iter_10000.caffemodel拷贝到build目录下,编译后执行,后面的数字代表生成的具体数字,由于waitkey的作用,需要把显示的28*28图片关掉才会显示预测结果

./caffe_mint lenet.prototxt lenet_iter_10000.caffemodel 1

假如编译过程中出现了

fatal error: caffe/proto/caffe.pb.h: No such file or directory
 #include "caffe/proto/caffe.pb.h"

就用protoc从caffe/src/caffe/proto/caffe.proto生成caffe.pb.h和caffe.pb.cc,执行

protoc --cpp_out=/home/dzh/caffe/include/caffe/ caffe.proto

在caffe/include/caffe/目录下新建一个proto文件夹,将生成的caffe.pb.h和caffe.pb.cc两个文件放到里面

再编译,就可以了

相关的API接口解析在下一篇博客

猜你喜欢

转载自blog.csdn.net/CSDN_dzh/article/details/81980631
今日推荐