OpenCV:利用OpenCV2.4.9进行SVM图片分类测试

一、对不同红色图片进行分类,图片名称为预设颜色,有三种:深红色、粉红色、橘红色,分别对应标签名称"crimson", "pink","tangerine"。

二、利用图片的RGB所占比例作为三个特征,对红色图片进行SVM训练,并进行预测分类。训练所用图片和要预测的图片分布放在两个文件夹中。

       

三、代码【用到了Qt】分享给有需要的人,代码质量勿喷。

void xjImage::xjSVMtest2()//红色分类
{
#pragma region 数据情况
	const string xjSampleFolder = "E:/SVM/ImgColor/train";//样本数据文件夹
	vector<string> xjSampleFiles;//用来存储文件名
	SVMfunction * xjSvmFun = new SVMfunction();
	xjSvmFun->xjGetFilesNames(xjSampleFolder, xjSampleFiles);
	const int xjSampleSum = xjSampleFiles.size();//训练样本数量

	string xjLabelName[3] = { "crimson", "pink","tangerine" };//类别名称
	const int xjClassSum = 3;//样本类别数量
	const int xjFeatureSum = 3;//样本特征数量:RGB的均值

	//训练数据及标签
	Mat xjMatSampleTrain = Mat::zeros(xjSampleSum, xjFeatureSum, CV_32FC1);//据说好像必须是CV_32FC1
	Mat xjMatSampleLabel = Mat::zeros(xjSampleSum, 1, CV_32SC1);
#pragma endregion

#pragma region 创建训练数据:3个特征和标签
	for (int i = 0; i < xjSampleSum; i++)
	{
		string xjTrainName = xjSampleFiles[i];
		Mat xjMatTrainData = imread(xjTrainName);
		float Pred = 0, Pgreen = 0, Pblue = 0;
		int red, green, blue;
		float rgbSum = 0;
		for (int r = 1; r < xjMatTrainData.rows; r++)
		{
			for (int c = 1; c < xjMatTrainData.cols; c++)
			{
				red = xjMatTrainData.at<Vec3b>(r, c)[2];
				green = xjMatTrainData.at<Vec3b>(r, c)[1];
				blue = xjMatTrainData.at<Vec3b>(r, c)[0];
				rgbSum = red + green + blue;
				Pred += red / rgbSum;
				Pgreen += green / rgbSum;
				Pblue += blue / rgbSum;
			}
		}
		Pred /= (xjMatTrainData.rows*xjMatTrainData.cols);
		Pgreen /= (xjMatTrainData.rows*xjMatTrainData.cols);
		Pblue /= (xjMatTrainData.rows*xjMatTrainData.cols);

		xjMatSampleTrain.at<float>(i, 0) = Pred;//特征1
		xjMatSampleTrain.at<float>(i, 1) = Pgreen;//特征2
		xjMatSampleTrain.at<float>(i, 2) = Pblue;//特征3
		int xjLabel = xjSvmFun->xjGetLabelByFileName(xjTrainName, xjLabelName);
		xjMatSampleLabel.at<int>(i, 0) = xjLabel;//标签(类别)
	}
#pragma endregion

#pragma region SVM参数和训练模型
	CvSVMParams SVMparameter;//参数
	SVMparameter.svm_type = CvSVM::C_SVC;
	SVMparameter.kernel_type = CvSVM::LINEAR;
	SVMparameter.degree = 1.0;
	SVMparameter.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);

	CvSVM SVM;//训练模型
	SVM.train(xjMatSampleTrain, xjMatSampleLabel, Mat(), Mat(), SVMparameter);
#pragma endregion

#pragma region 预测分类
	string xjResult = "";
	const string xjForecastFolder = "E:/SVM/ImgColor/forecast";
	vector<string> xjForecastFiles;
	xjSvmFun->xjGetFilesNames(xjForecastFolder, xjForecastFiles);
	const int xjForecastSum = xjForecastFiles.size();//预测数据数量
	for (int i = 0; i < xjForecastSum; i++)
	{
		//预测数据特征和标签
		Mat xjForecastData = Mat::zeros(1, xjFeatureSum, CV_32FC1);
		Mat xjForecastLabel;

		string xjForecastName = xjForecastFiles[i];
		Mat xjMatForecast = imread(xjForecastName);
		float Pred = 0, Pgreen = 0, Pblue = 0;
		int red, green, blue;
		float rgbSum = 0;
		for (int r = 1; r < xjMatForecast.rows; r++)
		{
			for (int c = 1; c < xjMatForecast.cols; c++)
			{
				red = xjMatForecast.at<Vec3b>(r, c)[2];
				green = xjMatForecast.at<Vec3b>(r, c)[1];
				blue = xjMatForecast.at<Vec3b>(r, c)[0];
				rgbSum = red + green + blue;
				Pred += red / rgbSum;
				Pgreen += green / rgbSum;
				Pblue += blue / rgbSum;
			}
		}
		Pred /= (xjMatForecast.rows*xjMatForecast.cols);
		Pgreen /= (xjMatForecast.rows*xjMatForecast.cols);
		Pblue /= (xjMatForecast.rows*xjMatForecast.cols);
		//三个特征
		xjForecastData.at<float>(0, 0) = Pred;
		xjForecastData.at<float>(0, 1) = Pgreen;
		xjForecastData.at<float>(0, 2) = Pblue;

		//预测
		SVM.predict(xjForecastData, xjForecastLabel);

		//预测结果
		int xjClassLabel = xjForecastLabel.at<float>(0, 0);
		string xjForecastResult = xjLabelName[xjClassLabel];//结果

		QString qFileFullPath = QString::fromStdString(xjForecastName);
		QFileInfo xjFileInfo(qFileFullPath);
		QString qFileName = xjFileInfo.completeBaseName();
		string xjFileName = qFileName.toStdString();
		putText(xjMatForecast, xjFileName + "" + xjForecastResult, Point(30, 30),
			FONT_HERSHEY_SIMPLEX, 0.5, Scalar(255, 0, 0), 1, 1);
		imshow("结果对比", xjMatForecast);
		waitKey(200);
		xjResult += xjFileName + "--" + xjForecastResult + "\r\n";
	}
#pragma endregion
	QString xjQResult = QString::fromStdString(xjResult);
	QMessageBox::information(NULL, "提示", xjQResult);
}
//获取文件夹下的所有文件名称
void SVMfunction::xjGetFilesNames(const string & xjFolder, vector<string> & xjFiles)
{
    //文件句柄
    long hFile = 0;
    struct _finddata_t fileinfo;  
    std::string p;
    if((hFile = _findfirst(p.assign(xjFolder).append("\\*").c_str(),&fileinfo)) != -1)  
    {  
        do
        {  
            //如果是目录,迭代之  
            //如果不是,加入列表  
            if((fileinfo.attrib & _A_SUBDIR))  
            {  
                if(strcmp(fileinfo.name,".") != 0 && strcmp(fileinfo.name,"..") != 0)  
                    xjGetFilesNames(p.assign(xjFolder).append("\\").append(fileinfo.name), xjFiles);  
            }  
            else
            {
                xjFiles.push_back(p.assign(xjFolder).append("\\").append(fileinfo.name));  
            }  
        } while (_findnext(hFile, &fileinfo) == 0);  
        //_findclose(hFile);  
    }
}
//根据文件名称获取分类标签
int SVMfunction::xjGetLabelByFileName(string FileFullPath, string xjLabelName[])
{
	int xjLabel = 0;

	//文件名称
	QString qFileFullPath = QString::fromStdString(FileFullPath);
	QFileInfo xjFileInfo(qFileFullPath);
	QString qFileName = xjFileInfo.completeBaseName();
	string xjFileName = qFileName.toStdString();

	string labelName;
	string::size_type idx;
	for (int i = 0; i < 3; i++)
	{
		labelName = xjLabelName[i];
		//字符串是否包含子字符串
		if (xjFileName.find(labelName) < xjFileName.length())
		{
			xjLabel = i;
			break;
		}
	}
	return xjLabel;
}

四、结果

最后三个分类是错误的。

五、分析

SVM是监督分类,在训练样本数量不足或者特征不明显的情况下,分类错误的概率会大大提高。

图片见:链接:https://pan.baidu.com/s/1WlnlEfkgmnx8Wvpp4c6Lag 
提取码:j2ds 

百度网盘系统维护,一定要有提取码

【注】本文参考 https://blog.csdn.net/akadiao/article/details/79278072 谢谢

发布了63 篇原创文章 · 获赞 58 · 访问量 8万+

猜你喜欢

转载自blog.csdn.net/xinjiang666/article/details/88352096