该程序利用opencv3的svm训练二分类例程,特征检测只是用了统计像素值这一简单的方法,后期完全可以自己改。亲测可用,效果其实还可以。可能我的识别的物体比较简单。
参考博客:https://blog.csdn.net/chaipp0607/article/details/68067098#commentsedit
他的程序用的opencv2,我改成opencv3了,这样大家就可以更快的学习svm了。至于具体原理,以后有时间再写,可以参考原博主的写的比较详细。
#include <stdio.h>
#include <time.h>
#include <opencv2/opencv.hpp>
#include <opencv/cv.h>
#include <iostream>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>
#include <io.h> //查找文件相关函数
using namespace std;
using namespace cv;
using namespace cv::ml;
void getFiles(string path, vector<string>& files);
void getBubble(Mat& trainingImages, vector<int>& trainingLabels);
void getNoBubble(Mat& trainingImages, vector<int>& trainingLabels);
int main()
{
//获取训练数据
Mat classes;
Mat trainingData;
Mat trainingImages;
vector<int> trainingLabels;
//getBubble()与getNoBubble()将获取一张图片后会将图片(特征)写入
// 到容器中,紧接着会将标签写入另一个容器中,这样就保证了特征
// 和标签是一一对应的关系push_back(0)或者push_back(1)其实就是
// 我们贴标签的过程。
getBubble(trainingImages, trainingLabels);
getNoBubble(trainingImages, trainingLabels);
//在主函数中,将getBubble()与getNoBubble()写好的包含特征的矩阵拷贝给trainingData,将包含标签的vector容器进行类
//型转换后拷贝到trainingLabels里,至此,数据准备工作完成,trainingData与trainingLabels就是我们要训练的数据。
Mat(trainingImages).copyTo(trainingData);
trainingData.convertTo(trainingData, CV_32FC1);
Mat(trainingLabels).copyTo(classes);
//classes.convertTo(classes, CV_32SC1);
// 创建分类器并设置参数
Ptr<SVM> SVM_params = SVM::create();
SVM_params->setType(SVM::C_SVC);
SVM_params->setKernel(SVM::LINEAR); //核函数
SVM_params->setDegree(0);
SVM_params->setGamma(1);
SVM_params->setCoef0(0);
SVM_params->setC(1);
SVM_params->setNu(0);
SVM_params->setP(0);
SVM_params->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER + TermCriteria::EPS, 1000, 0.01));
Ptr<TrainData> tData = TrainData::create(trainingData, ROW_SAMPLE, classes);
// 训练分类器
SVM_params->train(tData);
//保存模型
SVM_params->save("svm.xml");
cout << "训练好了!!!" << endl;
getchar();
return 0;
}
void getFiles(string path, vector<string>& files)
{
intptr_t hFile = 0;
struct _finddata_t fileinfo;
string p;
int i = 30;
if ((hFile = _findfirst(p.assign(path).append("\\*").c_str(), &fileinfo)) != -1)
{
do
{
if ((fileinfo.attrib & _A_SUBDIR))
{
if (strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
getFiles(p.assign(path).append("\\").append(fileinfo.name), files);
}
else
{
files.push_back(p.assign(path).append("\\").append(fileinfo.name));
}
}
while (_findnext(hFile, &fileinfo) == 0);
_findclose(hFile);
}
}
//获取正样本
//并贴标签为1
void getBubble(Mat& trainingImages, vector<int>& trainingLabels)
{
char * filePath = "E:\\SVM_train_data\\positive\\train"; //正样本路径
vector<string> files;
getFiles(filePath, files);
int number = files.size();
for (int i = 0; i < number; i++)
{
Mat SrcImage = imread(files[i].c_str());
SrcImage = SrcImage.reshape(1, 1);
trainingImages.push_back(SrcImage);
trainingLabels.push_back(1);//该样本为数字5
}
}
//获取负样本
//并贴标签为0
void getNoBubble(Mat& trainingImages, vector<int>& trainingLabels)
{
//char * filePath = "D:\\train\\no\\train"; //负样本路径
//char * filePath = "E:\\OCR_Recognition\\opencv_project\\SVM_train_data\\negative\\train"; //负样本路径
char * filePath = "E:\\SVM_train_data\\negative\\train"; //负样本路径
vector<string> files;
getFiles(filePath, files);
int number = files.size();
for (int i = 0; i < number; i++)
{
Mat SrcImage = imread(files[i].c_str());
SrcImage = SrcImage.reshape(1, 1);
trainingImages.push_back(SrcImage);
trainingLabels.push_back(0); //该样本不是数字5
}
}