KNN 算法理解

kNN算法又称为k近邻分类(k-nearest neighbor classification)算法。

一、基本思想:

        kNN算法的指导思想是“近朱者赤,近墨者黑”,由你的邻居来推断出你的类别。在距离空间里,如果一个样本的最接近的k个邻居里,绝大多数属于某个类别,则该样本也属于这个类别。俗话叫,“随大流”。

        代表论文:Discriminant Adaptive Nearest Neighbor Classification

 

二、算法描述:

1、算法步骤:

        step.1---初始化距离为最大值
        step.2---计算未知样本和每个训练样本的距离dist
        step.3---得到目前K个最临近样本中的最大距离maxdist
        step.4---如果dist小于maxdist,则将该训练样本作为K-最近邻样本
        step.5---重复步骤2、3、4,直到未知样本和所有训练样本的距离都算完
        step.6---统计K-最近邻样本中每个类标号出现的次数
        step.7---选择出现频率最大的类标号作为未知样本的类标号

2、K的选取:

        如何选择一个最佳的K值取决于数据。一般情况下,在分类时较大的K值能够减小噪声的影响。但会使类别之间的界限变得模糊。比如下图:


        待测样本(绿色圆圈)既可能分到红色三角形类,也可能分到蓝色正方形类。如果k取3,从图可见,待测样本的3个邻居在实线的内圆里,按多数投票结果,它属于红色三角形类,票数1:2.但是如果k取5,那么待测样本的最邻近的5个样本在虚线的圆里,按表决法,它又属于蓝色正方形类,票数2(红色三角形):3(蓝色正方形)。另外还有认为,经验规则,k一般低于训练样本数的平方根。

 

三、优缺点

1、优点

        简单,易于理解,易于实现,无需估计参数,无需训练。适合对稀有事件进行分类(例如当流失率很低时,比如低于0.5%,构造流失预测模型)。特别适合于多分类问题(multi-modal,对象具有多个类别标签),例如根据基因特征来判断其功能分类,kNN比SVM的表现要好。

2、缺点

        懒惰算法,对测试样本分类时的计算量大,内存开销大,评分慢可解释性较差,无法给出决策树那样的规则。

 

四、行业应用

        客户流失预测、欺诈侦 测等(更适合于稀有事件的分类问题)。 

 

五、性能问题

        kNN是一种懒惰算法,平时不好好学习,考试(对测试样本分类)时才临阵磨枪(临时去找k个近邻)。懒惰的后果:构造模型很简单,但在对测试样本分类地的系统开销大,因为要扫描全部训练样本并计算距离。已经有一些方法提高计算的效率,例如压缩训练样本量等。

 

六:测试代码和数据集:

        数据集:http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data

//    KNN.cpp     K-最近邻分类算法
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
#include <stdlib.h>
#include <stdio.h>
#include <memory.h>
#include <string.h>
#include <iostream>
#include <math.h>
#include <fstream>
using namespace std;
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    宏定义
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
#define  ATTR_NUM  4                        //属性数目
#define  MAX_SIZE_OF_TRAINING_SET  1000      //训练数据集的最大大小
#define  MAX_SIZE_OF_TEST_SET      100       //测试数据集的最大大小
#define  MAX_VALUE  10000.0                  //属性最大值
#define  K  7
//结构体
struct dataVector {
    int ID;                      //ID号
    char classLabel[15];             //分类标号
    double attributes[ATTR_NUM]; //属性 
};
struct distanceStruct {
    int ID;                      //ID号
    double distance;             //距离
    char classLabel[15];             //分类标号
};

////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    全局变量
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
struct dataVector gTrainingSet[MAX_SIZE_OF_TRAINING_SET]; //训练数据集
struct dataVector gTestSet[MAX_SIZE_OF_TEST_SET];         //测试数据集
struct distanceStruct gNearestDistance[K];                //K个最近邻距离
int curTrainingSetSize=0;                                 //训练数据集的大小
int curTestSetSize=0;                                     //测试数据集的大小
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    求 vector1=(x1,x2,...,xn)和vector2=(y1,y2,...,yn)的欧几里德距离
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
double Distance(struct dataVector vector1,struct dataVector vector2)
{
    double dist,sum=0.0;
    for(int i=0;i<ATTR_NUM;i++)
    {
        sum+=(vector1.attributes[i]-vector2.attributes[i])*(vector1.attributes[i]-vector2.attributes[i]);
    }
    dist=sqrt(sum);
    return dist;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    得到gNearestDistance中的最大距离,返回下标
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
int GetMaxDistance()
{
    int maxNo=0;
    for(int i=1;i<K;i++)
    {
        if(gNearestDistance[i].distance>gNearestDistance[maxNo].distance) 
            maxNo = i;
    }
    return maxNo;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    对未知样本Sample分类
//
////////////////////////////////////////////////////////////////////////////////////////////////////////
char* Classify(struct dataVector Sample)
{
    double dist=0;
    int maxid=0,freq[K],i,tmpfreq=1;;
    char *curClassLable=gNearestDistance[0].classLabel;
    memset(freq,1,sizeof(freq));
    //step.1---初始化距离为最大值
    for(i=0;i<K;i++)
    {
        gNearestDistance[i].distance=MAX_VALUE;
    }
    //step.2---计算K-最近邻距离
    for(i=0;i<curTrainingSetSize;i++)
    {
        //step.2.1---计算未知样本和每个训练样本的距离
        dist=Distance(gTrainingSet[i],Sample);
        //step.2.2---得到gNearestDistance中的最大距离
        maxid=GetMaxDistance();
        //step.2.3---如果距离小于gNearestDistance中的最大距离,则将该样本作为K-最近邻样本
        if(dist<gNearestDistance[maxid].distance) 
        {
            gNearestDistance[maxid].ID=gTrainingSet[i].ID;
            gNearestDistance[maxid].distance=dist;
            strcpy(gNearestDistance[maxid].classLabel,gTrainingSet[i].classLabel);
        }
    }
    //step.3---统计每个类出现的次数
    for(i=0;i<K;i++)  
    {
        for(int j=0;j<K;j++)
        {
            if((i!=j)&&(strcmp(gNearestDistance[i].classLabel,gNearestDistance[j].classLabel)==0))
            {
                freq[i]+=1;
            }
        }
    }
    //step.4---选择出现频率最大的类标号
    for(i=0;i<K;i++)
    {
        if(freq[i]>tmpfreq)  
        {
            tmpfreq=freq[i];
            curClassLable=gNearestDistance[i].classLabel;
        }
    }
    return curClassLable;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////
//
//    主函数
//
////////////////////////////////////////////////////////////////////////////////////////////////////////

void main()
{   
    char c; 
    char *classLabel="";
    int i,j, rowNo=0,TruePositive=0,FalsePositive=0;
    ifstream filein("iris.data");
    FILE *fp;
    if(filein.fail())
    {
        cout<<"Can't open data.txt"<<endl; 
        return;
    }

    //step.1---读文件 
    while(!filein.eof()) 
    {
        rowNo++;//第一组数据rowNo=1
        if(curTrainingSetSize>=MAX_SIZE_OF_TRAINING_SET) 
        {
            cout<<"The training set has "<<MAX_SIZE_OF_TRAINING_SET<<" examples!"<<endl<<endl;
            break ;
        }  
        //rowNo%3!=0的100组数据作为训练数据集
        if(rowNo%3!=0)
        {   
            gTrainingSet[curTrainingSetSize].ID=rowNo;
            for(int i = 0;i < ATTR_NUM;i++) 
            {     
                filein>>gTrainingSet[curTrainingSetSize].attributes[i];
                filein>>c;
            }   
            filein>>gTrainingSet[curTrainingSetSize].classLabel;
            curTrainingSetSize++;

        }
        //剩下rowNo%3==0的50组做测试数据集
        else if(rowNo%3==0)
        {
            gTestSet[curTestSetSize].ID=rowNo;
            for(int i = 0;i < ATTR_NUM;i++) 
            {    
                filein>>gTestSet[curTestSetSize].attributes[i];
                filein>>c;
            }  
            filein>>gTestSet[curTestSetSize].classLabel;
            curTestSetSize++;
        }
    }
    filein.close();
    //step.2---KNN算法进行分类,并将结果写到文件iris_OutPut.txt
    fp=fopen("iris_OutPut.txt","w+t");
    //用KNN算法进行分类
    fprintf(fp,"************************************程序说明***************************************\n");
    fprintf(fp,"** 采用KNN算法对iris.data分类。为了操作方便,对各组数据添加rowNo属性,第一组rowNo=1!\n");
    fprintf(fp,"** 共有150组数据,选择rowNo模3不等于0的100组作为训练数据集,剩下的50组做测试数据集\n");
    fprintf(fp,"***********************************************************************************\n\n");
    fprintf(fp,"************************************实验结果***************************************\n\n");
    for(i=0;i<curTestSetSize;i++)
    {
        fprintf(fp,"************************************第%d组数据**************************************\n",i+1);
        classLabel =Classify(gTestSet[i]);
        if(strcmp(classLabel,gTestSet[i].classLabel)==0)//相等时,分类正确
        {
            TruePositive++;
        }
        cout<<"rowNo: ";
        cout<<gTestSet[i].ID<<"    \t";
        cout<<"KNN分类结果:      ";

        cout<<classLabel<<"(正确类标号: ";
        cout<<gTestSet[i].classLabel<<")\n";
        fprintf(fp,"rowNo:  %3d   \t  KNN分类结果:  %s ( 正确类标号:  %s )\n",gTestSet[i].ID,classLabel,gTestSet[i].classLabel);
        if(strcmp(classLabel,gTestSet[i].classLabel)!=0)//不等时,分类错误
        {
            // cout<<"   ***分类错误***\n";
            fprintf(fp,"                                                                      ***分类错误***\n");
        }
        fprintf(fp,"%d-最临近数据:\n",K);
        for(j=0;j<K;j++)
        {
            // cout<<gNearestDistance[j].ID<<"\t"<<gNearestDistance[j].distance<<"\t"<<gNearestDistance[j].classLabel[15]<<endl;
            fprintf(fp,"rowNo:  %3d   \t   Distance:  %f   \tClassLable:    %s\n",gNearestDistance[j].ID,gNearestDistance[j].distance,gNearestDistance[j].classLabel);
        }
        fprintf(fp,"\n"); 
    }
    FalsePositive=curTestSetSize-TruePositive;
    fprintf(fp,"***********************************结果分析**************************************\n",i);
    fprintf(fp,"TP(True positive): %d\nFP(False positive): %d\naccuracy: %f\n",TruePositive,FalsePositive,double(TruePositive)/(curTestSetSize-1));
    fclose(fp);
    return;
}


猜你喜欢

转载自blog.csdn.net/fujianfafu/article/details/63253532