OpenCV2 利用 SVM 训练自己的图片进行数字识别

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/yc5300891/article/details/82261814

了解SVM:https://www.jianshu.com/p/61849d554001

1、获取样本:按数字类别把样本分别放进文件夹 0~9,并用批量重命名工具把每个文件夹中的样本依次命名为 0.jpg、1.jpg、2.jpg……(下面的训练代码正是按"文件夹\序号.jpg"的规则读取样本的)。

注意样本分辨率保持一致:每张图片都会被展开成一行特征向量,所有样本的像素数必须相同,才能组成同一个训练矩阵。

2、获取训练图像并贴上标签

样本示例:

代码讲解:

// Append every sample of digit 0 to the training set.
// Each image is read as grayscale, binarized with Otsu's threshold,
// flattened into a single row and labeled 0.
// Unreadable files and samples whose size differs from the rows already
// collected are reported on stderr and skipped.
void get_0(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 408; i++)// there are 408 samples of digit 0
	{
		const string path = "E:\\VSpro\\svm\\0\\" + to_string(i) + ".jpg";
		Mat  SrcImage = imread(path, 0);// read the sample as grayscale
		if (SrcImage.empty())
		{
			// imread returns an empty Mat when the file is missing or unreadable;
			// skip it instead of passing it on to threshold()/push_back().
			cerr << "无法读取样本: " << path << endl;
			continue;
		}
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);// binarize
		SrcImage = SrcImage.reshape(0, 1);// flatten the image into one row
		if (!trainingImages.empty() && SrcImage.cols != trainingImages.cols)
		{
			// Every row of the training matrix must have the same length,
			// so all samples need the same resolution.
			cerr << "样本分辨率不一致,已跳过: " << path << endl;
			continue;
		}
		trainingImages.push_back(SrcImage);// add to the training set
		trainingLabels.push_back(0);// label: digit 0
	}
}

3、配置 SVM 训练器参数,训练并保存模型

OpenCV 中的 SVM 参数优化:https://www.cnblogs.com/hust-yingjie/p/6582218.html

    // Build the classifier, configure it, train it and write the model to disk.
	CvSVM svm;
	CvSVMParams params;
	params.svm_type = SVM::C_SVC;// n-class classification with penalty parameter C
	params.kernel_type = SVM::LINEAR;// the RBF kernel gave worse results here
	// Only the iteration limit (10000) is enabled as a stopping criterion;
	// with CV_TERMCRIT_ITER alone the epsilon argument is not used.
	params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 10000, 1e-6);
	cout << "训练中..." << endl;
	// train_auto also searches for good parameter values via cross-validation.
	svm.train_auto(trainingData, classes, Mat(), Mat(), params);
	svm.save("E:\\VSpro\\svm\\svm.xml");

4、验证识别效果

    
// Recognize a single digit with the SVM model saved by the training program.
// img: a binarized image with the same resolution as the training samples;
//      it may be an ROI cut out of a larger image.
// Returns the predicted label 0-9 as a numeric char value (not the ASCII
// characters '0'-'9').
char onenumber(Mat img)
{
	CvSVM svm;
	svm.load("E:\\VSpro\\svm\\svm.xml");

	// reshape() cannot change the row count of a non-continuous Mat (such as an
	// ROI of a larger image), so copy such input into contiguous memory first.
	Mat sample = img.isContinuous() ? img : img.clone();
	sample = sample.reshape(0, 1);// one row per sample, exactly as in training
	sample.convertTo(sample, CV_32FC1);// the model was trained on CV_32FC1 data

	return static_cast<char>(static_cast<int>(svm.predict(sample)));
}

识别两位数时,先分别识别两个数字,再比较两个数字区域的 X 坐标:X 坐标较小(位于左侧)的为十位,较大的为个位。

附训练代码:

#include <cstdio>
#include <iostream>
#include <opencv2/opencv.hpp>
#include <stdio.h>
#include <string.h>
using namespace std;
using namespace cv;

void get_0(Mat& trainingImages, vector<int>& trainingLabels);
void get_1(Mat& trainingImages, vector<int>& trainingLabels);
void get_2(Mat& trainingImages, vector<int>& trainingLabels);
void get_3(Mat& trainingImages, vector<int>& trainingLabels);
void get_4(Mat& trainingImages, vector<int>& trainingLabels);
void get_5(Mat& trainingImages, vector<int>& trainingLabels);
void get_6(Mat& trainingImages, vector<int>& trainingLabels);
void get_7(Mat& trainingImages, vector<int>& trainingLabels);
void get_8(Mat& trainingImages, vector<int>& trainingLabels);
void get_9(Mat& trainingImages, vector<int>& trainingLabels);

int main()
{
	//获取训练数据
	Mat classes;
	Mat trainingData;
	Mat trainingImages;
	vector<int> trainingLabels;
	get_0(trainingImages, trainingLabels);
	get_1(trainingImages, trainingLabels);
	get_2(trainingImages, trainingLabels);
	get_3(trainingImages, trainingLabels);
	get_4(trainingImages, trainingLabels);
	get_5(trainingImages, trainingLabels);
	get_6(trainingImages, trainingLabels);
	get_7(trainingImages, trainingLabels);
	get_8(trainingImages, trainingLabels);
	get_9(trainingImages, trainingLabels);
	Mat(trainingImages).copyTo(trainingData);
	trainingData.convertTo(trainingData, CV_32FC1);
	Mat(trainingLabels).copyTo(classes);
	//配置SVM训练器参数
	CvSVMParams params;
	params.svm_type = SVM::C_SVC;
	params.kernel_type = SVM::LINEAR;//RBF效果不好
	params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 10000, 1e-6);
	//训练
	CvSVM svm;
	cout << "训练中..." << endl;
	svm.train_auto(trainingData, classes, Mat(), Mat(), params);
	//保存模型
	svm.save("E:\\VSpro\\vsm\\svm.xml");
	cout << "训练好了!!!" << endl;
	return 0;
}

void get_0(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 408; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\0\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(0);
	}
}
void get_1(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 1127; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\1\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(1);
	}
}
void get_2(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 1218; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\2\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(2);
	}
}
void get_3(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 1188; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\3\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(3);
	}
}
void get_4(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 1133; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\4\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(4);
	}
}
void get_5(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 1109; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\5\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(5);
	}
}
void get_6(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 584; i++)
	{
		Mat  SrcImage = imread("E:E:\\VSpro\\svm\\6\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(6);
	}
}
void get_7(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 416; i++)
	{
		Mat  SrcImage = imread("E:E:\\VSpro\\svm\\7\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(7);
	}
}
void get_8(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 374; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\8\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(8);
	}
}
void get_9(Mat& trainingImages, vector<int>& trainingLabels)
{
	for (int i = 0; i < 421; i++)
	{
		Mat  SrcImage = imread("E:\\VSpro\\svm\\9\\" + to_string(i) + ".jpg", 0);
		threshold(SrcImage, SrcImage, 0, 255, CV_THRESH_OTSU + CV_THRESH_BINARY);
		SrcImage = SrcImage.reshape(0, 1);
		trainingImages.push_back(SrcImage);
		trainingLabels.push_back(9);
	}
}


猜你喜欢

转载自blog.csdn.net/yc5300891/article/details/82261814