目录
训练模型 OpenCVSharp.ml.svm.Train()
分类预测 OpenCVSharp.ml.svm.Predict()
有空再继续补充知识点了。。。
核心函数:
训练模型 OpenCVSharp.ml.svm.Train()
svm.Train(trainingDataMat, SampleTypes.RowSample, labelsMat);
svm
: SVM 分类器对象,通过SVM.Create()
创建。trainingDataMat
: 训练数据的特征矩阵,类型为Mat
。每行代表一个样本,每列代表一个特征。SampleTypes.RowSample
: 训练数据的样本类型,指定为RowSample
,表示每行是一个样本。labelsMat
: 训练数据的标签矩阵,类型为Mat
。每行对应训练数据的标签。
分类预测 OpenCVSharp.ml.svm.Predict()
float response = svm.Predict(sampleMat);
svm
: SVM 分类器对象,已经训练好的模型。sampleMat
: 待预测样本的特征矩阵,类型为Mat
。该矩阵表示一个或多个待预测样本的特征信息。- 通过将待预测样本的特征矩阵传递给
Predict()
方法,该方法将返回一个浮点数作为预测结果
安装库和源代码
run之前要记得安装库
源代码如下,可以直接run:
public void Run()
{
// 设置训练数据
int[] labels = { 1, -1, -1, -1 };
float[,] trainingData = { { 501, 10 }, { 255, 10 }, { 501, 255 }, { 10, 501 } };
Mat trainingDataMat = new Mat(4, 2, MatType.CV_32F, trainingData);
Mat labelsMat = new Mat(4, 1, MatType.CV_32SC1, labels);
// 创建并配置SVM分类器
SVM svm = SVM.Create();
svm.Type = SVM.Types.CSvc;
svm.KernelType = SVM.KernelTypes.Linear;
svm.TermCriteria = new TermCriteria(CriteriaTypes.MaxIter, 100, 1e-6);
//使用训练数据和标签进行分类
svm.Train(trainingDataMat, SampleTypes.RowSample, labelsMat);
// 创建可视化图像
int width = 512, height = 512;
Mat image = Mat.Zeros(height, width, MatType.CV_8UC3);
// 根据 SVM 的决策区域设置颜色
Vec3b green = new Vec3b(0, 255, 0), blue = new Vec3b(255, 0, 0);
for (int i = 0; i < image.Rows; i++)
{
for (int j = 0; j < image.Cols; j++)
{
//创建一个包含当前坐标的样本
Mat sampleMat = new Mat(1, 2, MatType.CV_32F, new float[] { j, i });
//使用 SVM 对样本进行预测
float response = svm.Predict(sampleMat);
//根据预测结果设置图像颜色
if (response == 1)
image.At<Vec3b>(i, j) = green;
else if (response == -1)
image.At<Vec3b>(i, j) = blue;
}
}
// 在图像上显示训练数据点
int thickness = -1;
Cv2.Circle(image, new Point(501, 10), 5, new Scalar(0, 0, 0), thickness);
Cv2.Circle(image, new Point(255, 10), 5, new Scalar(255, 255, 255), thickness);
Cv2.Circle(image, new Point(501, 255), 5, new Scalar(255, 255, 255), thickness);
Cv2.Circle(image, new Point(10, 501), 5, new Scalar(255, 255, 255), thickness);
// 在图像上显示支持向量
thickness = 2;
Mat sv = svm.GetSupportVectors();
for (int i = 0; i < sv.Rows; i++)
{
Point2f v = sv.At<Point2f>(i);
Cv2.Circle(image, new Point((int)v.X, (int)v.Y), 6, new Scalar(128, 128, 128), thickness);
}
//保存并显示图像
Cv2.ImWrite("result.png", image);
Cv2.ImShow("SVM Simple Example", image);
Cv2.WaitKey();
return;