KNN算法 针对分类和预测详解 原文地址:KNN算法

原文地址:KNN算法

K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KNN算法是相对比较容易理解的算法。其中的K表示最接近自己的K个数据样本。KNN算法和K-Means算法不同的是,K-Means算法用来聚类,用来判断哪些东西是一个比较相近的类型,而KNN算法是用来做归类的,也就是说,有一个样本空间里的样本分成很几个类型,然后,给定一个待分类的数据,通过计算接近自己最近的K个样本来判断这个待分类数据属于哪个分类。你可以简单的理解为由那离自己最近的K个点来投票决定待分类数据归为哪一类

Wikipedia上的KNN词条中有一个比较经典的图如下:

从上图中我们可以看到,图中的有两个类型的样本数据,一类是蓝色的正方形,另一类是红色的三角形。而那个绿色的圆形是我们待分类的数据。

  • 如果K=3,那么离绿色点最近的有2个红色三角形和1个蓝色的正方形,这3个点投票,于是绿色的这个待分类点属于红色的三角形。
  • 如果K=5,那么离绿色点最近的有2个红色三角形和3个蓝色的正方形,这5个点投票,于是绿色的这个待分类点属于蓝色的正方形。

我们可以看到,机器学习的本质——是基于一种数据统计的方法!那么,这个算法有什么用呢?我们来看几个示例。

产品质量判断

假设我们需要判断纸巾的品质好坏,纸巾的品质好坏可以抽像出两个向量,一个是“酸腐蚀的时间”,一个是“能承受的压强”。如果我们的样本空间如下:(所谓样本空间,又叫Training Data,也就是用于机器学习的数据)

向量X1

耐酸时间(秒)

向量X2

圧强(公斤/平方米)

品质Y

7

7

7

4

3

4

1

4

那么,如果 X1 = 3 和 X2 = 7, 这个毛巾的品质是什么呢?这里就可以用到KNN算法来判断了。

假设K=3,K应该是一个奇数,这样可以保证不会有平票,下面是我们计算(3,7)到所有点的距离。(关于那些距离公式,可以参看K-Means算法中的距离公式

向量X1

耐酸时间(秒)

向量X2

圧强(公斤/平方米)

计算到 (3, 7)的距离

向量Y

7

7

 坏

7

4

 N/A

3

4

 好

1

4

 好

所以,最后的投票,好的有2票,坏的有1票,最终需要测试的(3,7)是合格品。(当然,你还可以使用权重——可以把距离值做为权重,越近的权重越大,这样可能会更准确一些)

注:示例来自这里K-NearestNeighbors Excel表格下载

预测

假设我们有下面一组数据,假设X是流逝的秒数,Y值是随时间变换的一个数值(你可以想像是股票值)

那么,当时间是6.5秒的时候,Y值会是多少呢?我们可以用KNN算法来预测之。

这里,让我们假设K=2,于是我们可以计算所有X点到6.5的距离,如:X=5.1,距离是 | 6.5 – 5.1 | = 1.4, X = 1.2 那么距离是 | 6.5 – 1.2 | = 5.3 。于是我们得到下面的表:

注意,上图中因为K=2,所以得到X=4 和 X =5.1的点最近,得到的Y的值分别为27和8,在这种情况下,我们可以简单的使用平均值来计算:

于是,最终预测的数值为:17.5

注:示例来自这里KNN_TimeSeries Excel表格下载

插值,平滑曲线

KNN算法还可以用来做平滑曲线用,这个用法比较另类。假如我们的样本数据如下(和上面的一样):

要平滑这些点,我们需要在其中插入一些值,比如我们用步长为0.1开始插值,从0到6开始,计算到所有X点的距离(绝对值),下图给出了从0到0.5 的数据:

下图给出了从2.5到3.5插入的11个值,然后计算他们到各个X的距离,假值K=4,那么我们就用最近4个X的Y值,然后求平均值,得到下面的表:

于是可以从0.0, 0.1, 0.2, 0.3 …. 1.1, 1.2, 1.3…..3.1, 3.2…..5.8, 5.9, 6.0 一个大表,跟据K的取值不同,得到下面的图:

注:示例来自这里KNN_Smoothing Excel表格下载


OPENCV运用KNN实现印刷数字分类识别:

[cpp]  view plain  copy
  1. #include "stdafx.h"  
  2. #include"iostream"  
  3. #include"opencv.hpp"  
  4. using namespace std;  
  5. using namespace cv;  
  6.   
  7.   
  8. int _tmain(int argc, _TCHAR* argv[])  
  9. {  
  10.   
  11.     int sample_num = 10, class_num = 10;  
  12.     int size_row = 28,size_col=20;  
  13.     CvMat *train_data = cvCreateMat(sample_num * class_num, size_row * size_col, CV_32FC1);  
  14.     CvMat *train_response = cvCreateMat(sample_num * class_num, 1, CV_32FC1);  
  15.     IplImage* src_image = cvCreateImage(cvSize(size_row, 20), IPL_DEPTH_8U, 1);  
  16.     CvMat*img = cvCreateMat(size_row, size_col, CV_32FC1);  
  17.     CvMat row,data;  
  18.     CvMat row_header, *row1;  
  19.     char path[17]=".//image//00.jpg";  
  20.     for(int i = 0; i < 10; i++)  
  21.     {  
  22.         for (int j = 0; j < 10; j++)  
  23.         {  
  24.             path[10] = i+'0';  
  25.             path[11] = j + '0';  
  26.             src_image = cvLoadImage(path, 0);  
  27.             cvShowImage("pa",src_image);  
  28.             waitKey(1);  
  29.             cvGetRow(train_response, &row, i * sample_num + j);  
  30.             cvSet(&row, cvRealScalar(i));  
  31.   
  32.             cvGetRow(train_data, &row, i * sample_num + j);  
  33.             cvConvertScale(src_image, img, 1.0/255, 0);//scale = 0.0039215 = 1/255;   
  34.             cvGetSubRect(img, &data, cvRect(0, 0, size_col ,size_row));  
  35.             //convert data matrix sizexsize to vecor  
  36.             row1 = cvReshape(&data, &row_header, 1, 1);//new_cn=1,1个通道,new_rows=1,1行  
  37.             cvCopy(row1, &row, NULL);             
  38.         }  
  39.     }  
  40.     CvKNearest *knn = new CvKNearest(train_data, train_response, 0, false, 1);  
  41.   
  42.     CvMat data1;  
  43.     IplImage* pimage=cvCreateImage(cvSize(size_row, 20), IPL_DEPTH_8U, 1);   
  44.     CvMat*img1 = cvCreateMat(size_row , size_col, CV_32FC1);  
  45.     CvMat mathdr, *vec;  
  46.     int result = 0;  
  47.     char test_name[17] = ".//test//110.jpg";  
  48.     for (int i = 0; i < 10; i++)  
  49.     {  
  50.         for (int j = 0; j < 10; j++)  
  51.         {  
  52.             test_name[10] = i + '0';  
  53.             test_name[11] = j + '0';  
  54.             pimage = cvLoadImage(test_name, 0);  
  55.             cvConvertScale(pimage, img1, 1.0 / 255, 0);  
  56.             cvGetSubRect(img1, &data1, cvRect(0, 0, size_col, size_row));  
  57.             vec = cvReshape(&data1, &mathdr, 0, 1);  
  58.             result = knn->find_nearest(vec, 1, 0, 0, 0, 0);  
  59.             cout << i << j << ":";  
  60.             cout << result << endl;  
  61.         }  
  62.     }  
  63.     system("Pause");  
  64.     return 0;  
  65. }  


后记

最后,我想再多说两个事,

1) 一个是机器学习,算法基本上都比较简单,最难的是数学建模,把那些业务中的特性抽象成向量的过程,另一个是选取适合模型的数据样本。这两个事都不是简单的事。算法反而是比较简单的事。

2)对于KNN算法中找到离自己最近的K个点,是一个很经典的算法面试题,需要使用到的数据结构是“最大堆——Max Heap”,一种二叉树。你可以看看相关的算法。

猜你喜欢

转载自blog.csdn.net/weixin_40355324/article/details/80480832