// OpenCV 3 SVM: training and testing example

#include<iostream>
#include<opencv2/opencv.hpp>
#include<opencv2/ml.hpp>
#include<fstream>
using namespace std;
using namespace cv;
using namespace ml;
int main(int argc, char** argv)
{
	Mat train_img, train_lables;
	string imgName;
	fstream pos_img("pos_train.txt");
	fstream neg_img("neg_train.txt");
	int i = 0;
	Mat plate_img,noplate_img;
	while (getline(pos_img, imgName))
	{
		cout << "Read" << imgName << endl;
		imgName = "E:\\OpencvVideo\\Plate_recongnize\\EasyPR-master\\resources\\train\\svm训练\\has\\train\\" + imgName;
		plate_img = imread(imgName,0);
		equalizeHist(plate_img, plate_img);
		if (plate_img.empty())
		{
			cout << "Can not find img..." << endl;
		}
		train_img.push_back(plate_img.reshape(1, 1));
		train_lables.push_back(1);
		i++;
	}
	
	int k = 0;
	while (getline(neg_img, imgName))
	{
		cout << "Read" << imgName << endl;
		imgName = "E:\\OpencvVideo\\Plate_recongnize\\EasyPR-master\\resources\\train\\svm训练\\no\\train\\" + imgName;
		noplate_img = imread(imgName,0);
		equalizeHist(noplate_img, noplate_img);
		if (noplate_img.empty())
		{
			cout << "Can not find img..." << endl;
		}
		train_img.push_back(noplate_img.reshape(1, 1));
		train_lables.push_back(0);
		k++;
	}
	train_img.convertTo(train_img, CV_32FC1);
	train_lables.convertTo(train_lables, CV_32SC1);
	Ptr<SVM> my_svm = SVM::create();
	my_svm->setKernel(SVM::RBF);
	my_svm->setC(1);
	my_svm->setGamma(0.03);
	my_svm->setCoef0(0.1);
	my_svm->setNu(0.1);
	my_svm->setP(0.1);
	my_svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 10000, 0.0001));
	Ptr<TrainData> train_data = TrainData::create(train_img, ROW_SAMPLE, train_lables);
	double t = (double)getTickCount();// Time
	my_svm->train(train_data);
	t = (double)(getTickCount() - t) / getTickFrequency();
	cout << "Time" << t << endl;
	cout << "Train Done" << endl;
	//Save model
	my_svm->save("svm_heibai.xml");
	cout << "Save success" << endl;

	//read model
	/*my_svm->load("svm_chepai.xml");
	Ptr<SVM> svm = SVM::load("chepai_svm77.xml");*/

	//predict
	Mat test_img, test_lables;
	fstream test_pos_name("test_pos_img.txt");
	fstream test_neg_name("test_neg_img.txt");
	string test_img_name;
	
	//read data
	while (getline(test_pos_name, test_img_name))
	{
		cout << "Read" << test_img_name << endl;
		test_img_name = "E:\\OpencvVideo\\Plate_recongnize\\EasyPR-master\\resources\\train\\svm训练\\has\\test\\" + test_img_name;
		plate_img = imread(test_img_name,0);
		equalizeHist(plate_img, plate_img);
		//resize(plate_img, plate_img, Size(20, 20));
		if (plate_img.empty())
		{
			cout << "Can not find img..." << endl;
		}
		test_img.push_back(plate_img.reshape(1, 1));
		test_lables.push_back(1);
	}

	while (getline(test_neg_name, test_img_name))
	{
		cout << "Read" << test_img_name << endl;
		test_img_name = "E:\\OpencvVideo\\Plate_recongnize\\EasyPR-master\\resources\\train\\svm训练\\no\\test\\" + test_img_name;
		plate_img = imread(test_img_name,0);
		equalizeHist(plate_img, plate_img);
		//resize(plate_img, plate_img, Size(20, 20));
		if (plate_img.empty())
		{
			cout << "Can not find img..." << endl;
		}
		test_img.push_back(plate_img.reshape(1, 1));
		test_lables.push_back(0);
	}
	test_img.convertTo(test_img, CV_32FC1);//更改格式
	test_lables.convertTo(test_lables, CV_32SC1);
	cout << "Test data ready" << endl;
	Mat samples;
	int count = 0;
	for (int i = 0; i < test_img.rows; i++)
	{
		samples = test_img.row(i);//每一行进行预测
		int r = my_svm->predict(test_img.row(i));//预测
		int testlables = test_lables.at<int>(i, 0);
		if (static_cast<int>(r) == testlables)
			count++;
	}
	cout << count << endl;
	cout << test_img.rows << endl;
	float result = static_cast<float>(count) / static_cast<float>(test_img.rows);
	cout << result << endl;

	/*Mat samples;*/
	
	//samples = imread("E:/picture/teengirl.jpg");
	//samples.convertTo(samples, CV_32FC1);
	//int r = svm->predict(samples.reshape(1,1));//预测
	//cout << r<< endl;
	system("pause");
	waitKey(0);
	return 0;
}

// The feature dimension of the data passed to predict must be the same as the dimension used for training.

// (Blog page footer: "You may also like")

// Reposted from blog.csdn.net/xiexu911/article/details/80005149