基于OpenCV的 SVM算法实现数字识别(四)---代码实现

三、基于SVM算法实现手写数字识别

作为一个工科生,而非数学专业的学生,我们研究一个算法,是要将它用于实际领域的。下面给出基于OpenCV3.0SVM算法手写数字识别程序源码(参考http://blog.csdn.net/firefight/article/details/6452188)程序略有改动。

本部分将基于OpenCV实现简单的数字识别,待识别图像如下图所示,通过以下几个步骤实现图像中的数字的自动识别。 

                                                             

1.使用OpenCV训练手写数字识别分类器;

2.图像预处理及图像分割;

3.应用分类器进行识别。

3.1使用OpenCV训练手写数字识别分类器

所谓学习分类器就是根据训练样本,选取模型训练产生数字分类器,这里采用上文提到的SVM算法。

训练集使用MNIST,这个MNIST数据库是一个手写数字的数据库,它提供了六万的训练集和一万的测试集。它的图片是被规范处理过的,是一张被放在中间部位的28px*28px的灰度图。总共包含4个文件,每一个文件头部几个字节都记录着这些图片的信息,然后才是储存的图片信息,关于文件信息的具体描述可以参考下面这个网站:https://www.jianshu.com/p/4195577585e6

下面是利用OpenCV 3.2.0SVM相关API学习MNIST样本库产生样本函数的主要代码:(值得注意的是MNIST库中的图像是黑底白字的)

svm.h头文件

#pragma once
#include <stdio.h>
#include <tchar.h>
#include<opencv/cv.h>
#include<opencv/highgui.h>


#include <windows.h>
#include <stdlib.h>
#include <iostream>
using namespace std;
using namespace cv;


class NumTrainData
{
public:
	NumTrainData()
	{
		memset(data, 0, sizeof(data));//Sets buffers to a specified character. Init the data
		result = -1;
	}
public:
	float data[64];
	int result;
};

extern vector<NumTrainData> buffer;

int ReadTrainData(int maxCount);
void newSvmStudy(vector<NumTrainData>& trainData);
char JpgPredict(Mat src);

svm.cpp文件

#include "svm.h"


#include "opencv2/opencv.hpp"

using namespace cv;
using namespace std;
using namespace cv::ml;



#define SHOW_PROCESS 0
#define ON_STUDY 0
int featureLen = 64;

void swapBuffer(char *buf)//0123->3210
{
	char temp;
	temp = *(buf);
	*buf = *(buf + 3);
	*(buf + 3) = temp;

	temp = *(buf + 1);
	*(buf + 1) = *(buf + 2);
	*(buf + 2) = temp;
}

//获取ROI区域
void GetROI(Mat& src, Mat& dst)
{
	int left, right, top, bottom;
	left = src.cols;
	right = 0;
	top = src.rows;
	bottom = 0;//右下角为原点

			   //Get valid area
	for (int i = 0; i < src.rows; i++)
	{
		for (int j = 0; j < src.cols; j++)
		{
			if (src.at<uchar>(i, j) > 0)//获取src中i,j点的像素值,为灰度图像,值为0-255
			{
				if (j < left) left = j;
				if (j > right) right = j;
				if (i < top) top = i;
				if (i > bottom) bottom = i;
			}
		}
	}//将原点置于含有像素点的方框的左上角

	 //Point center;
	 //center.x=(left+right)/2;
	 //center.y=(top+bottom)/2;

	int width = right - left;
	int height = bottom - top;
	int len = (width < height) ? height : width;

	//create a squre
	dst = Mat::zeros(len, len, CV_8UC1);

	//Copy valid data to squre center
	Rect dstRect((len - width) / 2, (len - height) / 2, width, height);
	Rect srcRect(left, top, width, height);
	Mat dstROI = dst(dstRect);
	Mat srcROI = src(srcRect);
	srcROI.copyTo(dstROI);

}

int ReadTrainData(int maxCount)
{
	//Open image and label file
	const char fileName[] = "res//train-images.idx3-ubyte";//图像信息,以二进制方式存储  28*28
	const char LabelFileName[] = "res//train-labels.idx1-ubyte";//标签信息,以二进制方式存储

																//ofstream是从内存到硬盘,ifstream是从硬盘到内存,读取标准样本库
	ifstream lab_ifs(LabelFileName, ios_base::binary);
	ifstream ifs(fileName, ios_base::binary);

	if (ifs.fail() == true)//读取文件失败
		return -1;
		


	if (lab_ifs.fail() == true)//读取文件失败
		return -1;

	//Read train data number and image rows/clos
	char magicNum[4], ccount[4], crows[4], ccols[4];
	ifs.read(magicNum, sizeof(magicNum));//Read block of data
	ifs.read(ccount, sizeof(ccount));
	ifs.read(crows, sizeof(crows));
	ifs.read(ccols, sizeof(ccols));


	int count, rows, cols;
	swapBuffer(ccount);//Copies bytes between buffers.
	swapBuffer(crows);
	swapBuffer(ccols);

	memcpy(&count, ccount, sizeof(count));//Copies bytes between buffers.
	memcpy(&rows, crows, sizeof(rows));
	memcpy(&cols, ccols, sizeof(cols));

	//Just skip label header
	lab_ifs.read(magicNum, sizeof(magicNum));
	lab_ifs.read(ccount, sizeof(ccount));

	//Create source and show image matrix
	Mat src = Mat::zeros(rows, cols, CV_8UC1);//28*28 piex single channel image
	Mat temp = Mat::zeros(8, 8, CV_8UC1);
	Mat img, dst;

	char label = 0;
	Scalar templateColor(255, 0, 255);

	NumTrainData rtd;

	//int loop=1000;
	int total = 0;



	while (!ifs.eof())//Indicates if the end of a stream has been reached.
	{
		if (total >= count)//total train data number
			break;

		total++;
		//cout << total << endl;

		//Read label
		lab_ifs.read(&label, 1);//读取标签,1个字节
		label = label + '0';//转换为ASCII码中的罗马数字

							//Read source data
		ifs.read((char*)src.data, rows*cols);//读取训练图像数据;每个像素被转成了0-255,0代表着白色,255代表着黑色。
		GetROI(src, dst);

#if(SHOW_PROCESS)
		//Too small to watch
		img = Mat::zeros(dst.rows * 10, dst.cols * 10, CV_8UC1);
		resize(dst, img, img.size());

		stringstream ss;
		ss << "Number" << label;
		string text = ss.str();

		putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, template);

#endif

		rtd.result = label;
		resize(dst, temp, temp.size());//将dst缩放成一个8*8的temp矩阵
									   //tehreshold(temp,temp,10,1,CT_THRESH_BINARY);

		for (int i = 0; i < 8; i++)
		{
			for (int j = 0; j < 8; j++)
			{
				rtd.data[i * 8 + j] = temp.at<uchar>(i, j);
			}
		}

		buffer.push_back(rtd);

		//if(waitKey(0)==27)//ESC to quit
		//break;

		maxCount--;

		if (maxCount == 0)
		{
			//cout << "maxcount=" << maxCount << endl;
			system("pause");
			break;
		}

	}

	//buffer中存储了maxcount个8*8的矩阵和它所具有的标签
	ifs.close();
	lab_ifs.close();

	return 0;
}

void newSvmStudy(vector<NumTrainData>& trainData)
{
	int testCount = trainData.size();//60000

	Mat m = Mat::zeros(1, featureLen, CV_32FC1);
	Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
	Mat res = Mat::zeros(testCount, 1, CV_32SC1);

	for (int i = 0; i < testCount; i++)
	{
		NumTrainData td = trainData.at(i);
		memcpy(m.data, td.data, featureLen * sizeof(float));
		normalize(m, m);
		memcpy(data.data + i*featureLen * sizeof(float), m.data, featureLen * sizeof(float));


		res.at<int>(i, 0) = td.result;

		//res.at<unsigned int>(i, 0) = td.result;//存储标签
	}

	////////////////////START RT TRAINNING///////////////
	//设置SVM参数
	Ptr<SVM> svm = SVM::create();
	svm->setType(SVM::C_SVC);//用于多类分类
	svm->setKernel(SVM::RBF);//采用高斯核函数
	svm->setTermCriteria(cv::TermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON));
	svm->setDegree(10.0);//高斯核的参数设置
	svm->setGamma(8.0);
	svm->setCoef0(1.0);
	svm->setC(10.0);
	svm->setNu(0.5);
	svm->setP(0.1);

	//训练
	Ptr<TrainData> tData = TrainData::create(data, ROW_SAMPLE, res);
	svm->train(tData);
	svm->save("res\\SVM_DATA.xml");

}

//预测数据
char JpgPredict(Mat src)
{
	Ptr<SVM> svm = Algorithm::load<ml::SVM>("res\\SVM_DATA.xml");
	svm->load("res\\SVM_DATA.xml");

	threshold(src, src, 230, 250, CV_THRESH_BINARY);
	Mat temp = Mat::zeros(8, 8, CV_8UC1);
	Mat m = Mat::zeros(1, featureLen, CV_32FC1);

	Mat element = getStructuringElement(MORPH_RECT, Size(2, 2));
	dilate(src, src, element);

	imshow("1", src);
	waitKey(30);

	resize(src, temp, temp.size());
	
	for (int i = 0; i < 8; i++)
	{
		for (int j = 0; j < 8; j++)
		{
			m.at<float>(0, j + i * 8) = temp.at<uchar>(i, j);
		}
	}

	normalize(m, m);// 该函数归一化输入数组使它的范数或者数值范围在一定的范围内。
	char ret = (char)svm->predict(m);//如果值为true而且是一个2类问题则返回判决函数值,否则返回类标签

	return ret;

}

3.2 图像预处理及图像分割

前面通过学习产生了分类器,但我们输入图像中的数字并不能直接作为测试输入。图像中的数字笔画有时并不规整,还可能相互重叠。因为本文例子为了简化用的是屏幕截图,所以位置形变校正,色彩亮度校正等等都省去了,但仍需要一些简单处理。下面先对输入图像进行一把简单的预处理,主要目的是将图像转成二值图,这样便于我们下一步分割和识别。这样做还有个好处,就是把其余的噪声也顺带去掉了。

接下来,就可以对图像进行分割了。由于我们的分类器只能对数字一个一个地识别,所以首先要把每个数字分割出来。基本思想是先用findContours()函数把基本轮廓找出来,然后通过简单验证以确认是否为数字的轮廓。对于那些通过验证的轮廓,接下去会用boundingRect()找出它们的包围盒。

Process.h文件

#pragma once
#include "svm.h"
#include "opencv2/opencv.hpp"


class Coordinate     //坐标类
{
public:
	double x, y;    //轮廓位置
	int order;      //轮廓向量contours中的第几个

	bool operator<(Coordinate &m)   //运算符重载,在sort()排序函数中使用
	{
		if (x < m.x)
				return true;
			else
				return false;
	}
};

void ImageProcess(Mat &srcImage);
void ImageFindRectangle(Mat &srcImage);

Process.cpp文件

#include "Process.h"



using namespace cv;
using namespace std;

Coordinate con[100] = { 0 }; //存放分割好的矩阵的中心坐标
vector<vector<Point>> contours;//定义一个存放边缘矩阵的容器
vector<Vec4i> hierarchy;  //定义一个存放树节点的前后关系的容器
Rect rect[100];            //定义一个存放分割好图像的矩阵,注意数据溢出关系
int i = 0;//全局变量


void ImageFindRectangle(Mat &srcImage)
{
	//使用contours迭代器遍历每一个轮廓,找到并画出包围这个轮廓的最小矩阵
	vector<vector<Point>>::iterator It;
	for (It = contours.begin(); It < contours.end(); It++)
	{
		//画出可包围数字的最小矩形
		Point2f vertex[4];
		rect[i] = boundingRect(*It);  //计算轮廓的垂直边界最小矩形,矩形是与图像上下边界平行的
									  //矩形左上角的点
		vertex[0] = rect[i].tl();
		//矩形左下角的点
		vertex[1].x = (float)rect[i].tl().x, vertex[1].y = (float)rect[i].br().y;
		//矩形右下角的点
		vertex[2] = rect[i].br();
		//矩形右上方的点
		vertex[3].x = (float)rect[i].br().x, vertex[3].y = (float)rect[i].tl().y;

		for (int j = 0; j < 4; j++)
			line(srcImage, vertex[j], vertex[(j + 1) % 4], Scalar(0, 0, 255), 1);


		con[i].x = (vertex[0].x + vertex[1].x + vertex[2].x + vertex[3].x) / 4.0;
		//根据中心点判断图图像的位置
		con[i].y = (vertex[0].y + vertex[1].y + vertex[2].y + vertex[3].y) / 4.0;
		con[i].order = i;

		i++;

	}



	sort(con, con + i);  //将con按升序排列
}


void ImageProcess(Mat &srcImage)
{
	Mat Image = Mat::zeros(srcImage.size(), CV_8U);
	Mat grayImage = Mat::zeros(srcImage.size(), CV_8U);
	//图像预处理
	cvtColor(srcImage, srcImage, COLOR_BGR2GRAY);   //转化为灰度图像
	threshold(srcImage, srcImage, 230, 255, CV_THRESH_BINARY);//阈值化




	//寻找图像边缘
	findContours(srcImage, contours, hierarchy, CV_RETR_EXTERNAL, CV_CHAIN_APPROX_NONE);//寻找图像边缘;函数用法参数见笔记
	Mat dstImage = Mat::zeros(Image.size(), CV_8U);
	
	drawContours(dstImage, contours, -1, Scalar(255, 0, 255));//在dstImage图像中画出边缘
	//进行分割
	ImageFindRectangle(dstImage);
	//存储分割矩阵
	Mat num[11];
	for (int j = 0; j < i; j++)
	{
		int k;
		k = con[j].order;
		srcImage(rect[k]).copyTo(num[j]);
	}
	cout << "i=" << i << endl;
	vector<char> res;



	for (int j = 0; j < i; j++)
	{
		res.push_back(JpgPredict(num[j]));
		//cout << JpgPredict(num[j]) << endl;
		
	}

	cout << "Predicted number is:";

	for (const auto&number : res)
	{
		cout <<number;
	//	system("pause");
	}

}

3.3 应用分类器进行识别

Main.cpp函数

#include "svm.h"
#include "Process.h"

#include <fstream>
#include <vector>

#include <opencv2/opencv.hpp>

using namespace cv;
using namespace std;

vector<NumTrainData> buffer;

#define ON_STUDY 0
#define ON_PROCESS 1

int main(void)
{
#if ON_STUDY
	int maxCount = 30000;
	ReadTrainData(maxCount);
	newSvmStudy(buffer);
#endif
#if ON_PROCESS
	Mat img = imread("Sample3.jpg");

	ImageProcess(img);
	waitKey(0);
#endif
	return 0;
}



识别结果如下:

                               

结果检测,SVM算法可以较好的识别手写数字,但是在编写代码的过程中发现一个问题,那就是这个算法对“1”数字的识别精度非常差,可能10张图中只能正确识别一次,不知道有没有大神能够给出一些建议?


上一篇:基于OpenCV的 SVM算法实现数字识别(三)---SMO求解



猜你喜欢

转载自blog.csdn.net/LIT_Elric/article/details/79202469