一、对不同红色图片进行分类,图片名称为预设颜色,有三种:深红色、粉红色、橘红色,分别对应标签名称"crimson", "pink","tangerine"。
二、利用图片的RGB所占比例作为三个特征,对红色图片进行SVM训练,并进行预测分类。训练所用图片和要预测的图片分布放在两个文件夹中。
三、代码【用到了Qt】分享给有需要的人,代码质量勿喷。
void xjImage::xjSVMtest2()//红色分类
{
#pragma region 数据情况
const string xjSampleFolder = "E:/SVM/ImgColor/train";//样本数据文件夹
vector<string> xjSampleFiles;//用来存储文件名
SVMfunction * xjSvmFun = new SVMfunction();
xjSvmFun->xjGetFilesNames(xjSampleFolder, xjSampleFiles);
const int xjSampleSum = xjSampleFiles.size();//训练样本数量
string xjLabelName[3] = { "crimson", "pink","tangerine" };//类别名称
const int xjClassSum = 3;//样本类别数量
const int xjFeatureSum = 3;//样本特征数量:RGB的均值
//训练数据及标签
Mat xjMatSampleTrain = Mat::zeros(xjSampleSum, xjFeatureSum, CV_32FC1);//据说好像必须是CV_32FC1
Mat xjMatSampleLabel = Mat::zeros(xjSampleSum, 1, CV_32SC1);
#pragma endregion
#pragma region 创建训练数据:3个特征和标签
for (int i = 0; i < xjSampleSum; i++)
{
string xjTrainName = xjSampleFiles[i];
Mat xjMatTrainData = imread(xjTrainName);
float Pred = 0, Pgreen = 0, Pblue = 0;
int red, green, blue;
float rgbSum = 0;
for (int r = 1; r < xjMatTrainData.rows; r++)
{
for (int c = 1; c < xjMatTrainData.cols; c++)
{
red = xjMatTrainData.at<Vec3b>(r, c)[2];
green = xjMatTrainData.at<Vec3b>(r, c)[1];
blue = xjMatTrainData.at<Vec3b>(r, c)[0];
rgbSum = red + green + blue;
Pred += red / rgbSum;
Pgreen += green / rgbSum;
Pblue += blue / rgbSum;
}
}
Pred /= (xjMatTrainData.rows*xjMatTrainData.cols);
Pgreen /= (xjMatTrainData.rows*xjMatTrainData.cols);
Pblue /= (xjMatTrainData.rows*xjMatTrainData.cols);
xjMatSampleTrain.at<float>(i, 0) = Pred;//特征1
xjMatSampleTrain.at<float>(i, 1) = Pgreen;//特征2
xjMatSampleTrain.at<float>(i, 2) = Pblue;//特征3
int xjLabel = xjSvmFun->xjGetLabelByFileName(xjTrainName, xjLabelName);
xjMatSampleLabel.at<int>(i, 0) = xjLabel;//标签(类别)
}
#pragma endregion
#pragma region SVM参数和训练模型
CvSVMParams SVMparameter;//参数
SVMparameter.svm_type = CvSVM::C_SVC;
SVMparameter.kernel_type = CvSVM::LINEAR;
SVMparameter.degree = 1.0;
SVMparameter.term_crit = cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6);
CvSVM SVM;//训练模型
SVM.train(xjMatSampleTrain, xjMatSampleLabel, Mat(), Mat(), SVMparameter);
#pragma endregion
#pragma region 预测分类
string xjResult = "";
const string xjForecastFolder = "E:/SVM/ImgColor/forecast";
vector<string> xjForecastFiles;
xjSvmFun->xjGetFilesNames(xjForecastFolder, xjForecastFiles);
const int xjForecastSum = xjForecastFiles.size();//预测数据数量
for (int i = 0; i < xjForecastSum; i++)
{
//预测数据特征和标签
Mat xjForecastData = Mat::zeros(1, xjFeatureSum, CV_32FC1);
Mat xjForecastLabel;
string xjForecastName = xjForecastFiles[i];
Mat xjMatForecast = imread(xjForecastName);
float Pred = 0, Pgreen = 0, Pblue = 0;
int red, green, blue;
float rgbSum = 0;
for (int r = 1; r < xjMatForecast.rows; r++)
{
for (int c = 1; c < xjMatForecast.cols; c++)
{
red = xjMatForecast.at<Vec3b>(r, c)[2];
green = xjMatForecast.at<Vec3b>(r, c)[1];
blue = xjMatForecast.at<Vec3b>(r, c)[0];
rgbSum = red + green + blue;
Pred += red / rgbSum;
Pgreen += green / rgbSum;
Pblue += blue / rgbSum;
}
}
Pred /= (xjMatForecast.rows*xjMatForecast.cols);
Pgreen /= (xjMatForecast.rows*xjMatForecast.cols);
Pblue /= (xjMatForecast.rows*xjMatForecast.cols);
//三个特征
xjForecastData.at<float>(0, 0) = Pred;
xjForecastData.at<float>(0, 1) = Pgreen;
xjForecastData.at<float>(0, 2) = Pblue;
//预测
SVM.predict(xjForecastData, xjForecastLabel);
//预测结果
int xjClassLabel = xjForecastLabel.at<float>(0, 0);
string xjForecastResult = xjLabelName[xjClassLabel];//结果
QString qFileFullPath = QString::fromStdString(xjForecastName);
QFileInfo xjFileInfo(qFileFullPath);
QString qFileName = xjFileInfo.completeBaseName();
string xjFileName = qFileName.toStdString();
putText(xjMatForecast, xjFileName + "" + xjForecastResult, Point(30, 30),
FONT_HERSHEY_SIMPLEX, 0.5, Scalar(255, 0, 0), 1, 1);
imshow("结果对比", xjMatForecast);
waitKey(200);
xjResult += xjFileName + "--" + xjForecastResult + "\r\n";
}
#pragma endregion
QString xjQResult = QString::fromStdString(xjResult);
QMessageBox::information(NULL, "提示", xjQResult);
}
//获取文件夹下的所有文件名称
void SVMfunction::xjGetFilesNames(const string & xjFolder, vector<string> & xjFiles)
{
//文件句柄
long hFile = 0;
struct _finddata_t fileinfo;
std::string p;
if((hFile = _findfirst(p.assign(xjFolder).append("\\*").c_str(),&fileinfo)) != -1)
{
do
{
//如果是目录,迭代之
//如果不是,加入列表
if((fileinfo.attrib & _A_SUBDIR))
{
if(strcmp(fileinfo.name,".") != 0 && strcmp(fileinfo.name,"..") != 0)
xjGetFilesNames(p.assign(xjFolder).append("\\").append(fileinfo.name), xjFiles);
}
else
{
xjFiles.push_back(p.assign(xjFolder).append("\\").append(fileinfo.name));
}
} while (_findnext(hFile, &fileinfo) == 0);
//_findclose(hFile);
}
}
//根据文件名称获取分类标签
int SVMfunction::xjGetLabelByFileName(string FileFullPath, string xjLabelName[])
{
int xjLabel = 0;
//文件名称
QString qFileFullPath = QString::fromStdString(FileFullPath);
QFileInfo xjFileInfo(qFileFullPath);
QString qFileName = xjFileInfo.completeBaseName();
string xjFileName = qFileName.toStdString();
string labelName;
string::size_type idx;
for (int i = 0; i < 3; i++)
{
labelName = xjLabelName[i];
//字符串是否包含子字符串
if (xjFileName.find(labelName) < xjFileName.length())
{
xjLabel = i;
break;
}
}
return xjLabel;
}
四、结果
最后三个分类是错误的。
五、分析
SVM是监督分类,在训练样本数量不足或者特征不明显的情况下,分类错误的概率会大大提高。
图片见:链接:https://pan.baidu.com/s/1WlnlEfkgmnx8Wvpp4c6Lag
提取码:j2ds
百度网盘系统维护,一定要有提取码
【注】本文参考 https://blog.csdn.net/akadiao/article/details/79278072 谢谢