OpenCV KNN数字分类

(1)、cv::ml::Knearest类:继承自cv::ml::StateModel,而cv::ml::StateModel又继承自cv::Algorithm;

(2)、create函数:为static,new一个KNearestImpl用来创建一个KNearest对象;

(3)、setDefaultK/getDefaultK函数:在预测时,设置/获取的K值;

(4)、setIsClassifier/getIsClassifier函数:设置/获取应用KNN是进行分类还是回归;

(5)、setEmax/getEmax函数:在使用KDTree算法时,设置/获取Emax参数值;

(6)、setAlgorithmType/getAlgorithmType函数:设置/获取KNN算法类型,目前支持两种:brute_force和KDTree;

(7)、findNearest函数:根据输入预测分类/回归结果。

 1 #include<iostream>
 2 #include <opencv2\opencv.hpp>
 3 using namespace cv;
 4 using namespace std;
 5 #include "test.h"
 6 
 7 int main()
 8 {
 9     Mat img = imread("1.png");
10     Mat gray;
11     cvtColor(img, gray, CV_BGR2GRAY);
12     threshold(gray, gray, 0, 255, CV_THRESH_BINARY);
13     // digits.png为2000 * 1000,其中每个数字的大小为20 * 20,
14     // 总共有5000((2000*1000) / (20*20))个数字,类型为[0~9],
15     // [0~9]10个数字每个数字有5000/10 = 500个样本
16     // 对其分割成单个20 * 20的图像并序列化成(转化成一个一维的数组)
17     int side = 20;
18     int m = gray.rows / side;
19     int n = gray.cols / side;
20     Mat data, labels;
21     for (int i = 0; i < m; i++) {
22 
23         int offsetRow = i * side;
24         for (int j = 0; j < n; j++) {
25 
26             int offsetCol = j * side;
27             // 截取20*20的小块
28             
29                 
30             Mat tmp;
31             
32             gray(Range(offsetRow, offsetRow + side), Range(offsetCol, offsetCol + side)).copyTo(tmp);
33             
34             data.push_back(tmp.reshape(0, 1));  // 序列化转换成一个一维向量
35             labels.push_back(i / 5);            // 每500个为一个label类型            
36         }
37     }
38     data.convertTo(data, CV_32F);
39     cout << "读取结束..." << endl;
40     //****************** 使用KNN算法训练********************//
41     int K = 7;    // 改变K值可能会出现不同的效果,K值越大,识别速度越慢
42     Ptr<TrainData> tData = TrainData::create(data, ROW_SAMPLE, labels);
43     Ptr<KNearest> model = KNearest::create();
44     model->setDefaultK(K);
45     model->setIsClassifier(true);
46     model->train(tData);
47     model->save("KnnTest.xml");
48     ///********************测试模型***************************///
49     Mat test = imread(".\\test\\3.jpg", 0);//截取图像中一个数字
50     Mat bw;
51     threshold(test, bw, 0, 255, CV_THRESH_BINARY);
52     Mat I0 = bw.reshape(0, 1);
53     I0.convertTo(I0, CV_32F);
54     // 开始用KNN预测分类,返回识别结果
55     float r = model->predict(I0);
56     
57 }

猜你喜欢

转载自www.cnblogs.com/hsy1941/p/11717703.html
今日推荐