目标
在本教程中,您将学习如何:
- 使用OpenCV函数cv :: ml :: SVM :: train来构建基于SVM的分类器以及cv :: ml :: SVM :: predict测试其性能。
什么是SVM?
支持向量机(SVM)是由分离超平面正式定义的区分分类器。换句话说,给定标记的训练数据(监督学习),该算法输出一个最优的超平面,对新的例子进行分类。
超平面在哪个意义上是最优的?让我们考虑以下简单的问题:
对于属于两类之一的线性可分的一组2D点,找到一条分离的直线。
如果是线性可分的话,那么从上图中可以看出,会有很多的区分方法(BP神经网络即是)。SVM就是选择这么多的最优的一种算法。
- 注意
- 在这个例子中,我们处理笛卡尔平面中的直线和点,而不是高维空间中的超平面和矢量。这是对问题的简化。理解这一点很重要,这是因为我们的直觉更好地建立在易于想象的例子之上。然而,相同的概念适用于要分类的例子位于维数高于2的空间中的任务。
在上面的图片中,您可以看到有多条线路为问题提供了解决方案。他们中的任何一个都比其他人好吗?我们可以直观地定义一个标准来估计线条的价值:如果线条靠近点,则线条不好,因为线条对噪点敏感并且不能正确推广。因此,我们的目标应该是尽可能从所有点上找到通过的路线。
SVM算法的操作基于找到给训练样例最大最小距离的超平面。
这里补充上李航老师的《统计学习方法》中的部分内容,对SVM理论介绍的会更好。
上面是我看到的关于SVM最经典的解释。但关于最优解的问题,使用了拉格朗日方法,这个有点难,解决不了。
程序
#include <opencv2/core.hpp>#include <opencv2/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp>
using namespace cv;
using namespace cv::ml;
int main(int, char**)
{
// Data for visual representation
int width = 512, height = 512;
Mat image = Mat::zeros(height, width, CV_8UC3);
// Set up training data
int labels[4] = {1, -1, -1, -1};
float trainingData[4][2] = { {501, 10}, {255, 10}, {501, 255}, {10, 501} };
Mat trainingDataMat(4, 2, CV_32FC1, trainingData);//初始化
Mat labelsMat(4, 1, CV_32SC1, labels);//初始化
// Train the SVM
Ptr<SVM> svm = SVM::create();
svm->setType(SVM::C_SVC);
svm->setKernel(SVM::LINEAR);
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
svm->train(trainingDataMat, ROW_SAMPLE, labelsMat);//SVM训练
// Show the decision regions given by the SVM
Vec3b green(0,255,0), blue (255,0,0);
for (int i = 0; i < image.rows; ++i)
for (int j = 0; j < image.cols; ++j)
{
Mat sampleMat = (Mat_<float>(1,2) << j,i);//M_继承类M,对于小型矩阵比较适用。
float response = svm->predict(sampleMat);
if (response == 1)
image.at<Vec3b>(i,j) = green;
else if (response == -1)
image.at<Vec3b>(i,j) = blue;
}
// Show the training data
int thickness = -1;
int lineType = 8;
circle( image, Point(501, 10), 5, Scalar( 0, 0, 0), thickness, lineType );
circle( image, Point(255, 10), 5, Scalar(255, 255, 255), thickness, lineType );
circle( image, Point(501, 255), 5, Scalar(255, 255, 255), thickness, lineType );
circle( image, Point( 10, 501), 5, Scalar(255, 255, 255), thickness, lineType );
// Show support vectors
thickness = 2;
lineType = 8;
Mat sv = svm->getUncompressedSupportVectors();
for (int i = 0; i < sv.rows; ++i)
{
const float* v = sv.ptr<float>(i);
circle( image, Point( (int) v[0], (int) v[1]), 6, Scalar(128, 128, 128), thickness, lineType);
}
imwrite("result.png", image); // save the image
imshow("SVM Simple Example", image); // show it to the user
waitKey(0);
}