kNN分类 (k-nearest neighbor,k近邻法)

核心:物以类聚----根据邻近样本决定测试样本的类别。

一、概念

    所谓邻近样本,就是离它最近的k个样本,通过计算其与所有已知样本的距离来确定。

   ( 距离的计算方式有多种(https://blog.csdn.net/albert201605/article/details/81040556),kNN一般使用的是欧氏距离, 即两点间的空间距离,为两点向量差的L2范数。两个n维向量A(x11,x12,...,x1n)和B(x21,x22,...,x2n)间的欧氏距离为:

                                                                                                                                        )

         所以kNN分类的初始状态就是一些已知类别的样本(全部特征向量表示),对于一个未知类别的新样本,我们找到k个距离最小的 邻居样本,看这k个邻居属于哪个类的多,我们就认为新样本也属于那个类。

示例:

                                                                            

        如图所示,蓝色正方形和红色三角形分别是两个类别的已知样本(训练集),绿色圆形是待分类样本。如果k取3,离它最近的三个样本中有两个三角形一个正方形,三角形数多于正方形,所以认为小圆形属于三角形那一类;如果k取5,离它最近的五个样本中有两个三角形三个正方形,正方形数多于三角形,所以认为小圆形属于正方形那一类。

二、一般过程:

1)计算已训练集中样本与待测样本之间的距离;

2)按距离排序;

3)选取与当前样本距离最小的k个邻居样本;

4)确定此k个样本中各个类别的频率;

5)频率最高的类别作为该样本的预测分类。

三、注意事项:

1.kNN算法的三要素:k值选择、距离度量和分类决策规则都会对分类结果产生重要影响。

       k值选择:通常是不大于20的整数,通常由交叉验证选择最优的k.

       距离度量:不同的距离度量方法所确定的k个邻近点是不同的,会对结果产生影响,一般选用欧氏距离。

      分类决策规则:一般是多数表决,即k个邻居中多的说的算。可以根据不同距离的邻居对该样本产生的影响赋予不同的权重。

2.数据标准化:在开始实现算法之前,我们要考虑一个问题,不同特征的特征值范围可能有很大的差别,例如,我们要分辨一个人的性别,一个女生的身高是1.70m,体重是60kg,一个男生的身高是1.80m,体重是70kg,而一个未知性别的人的身高是1.81m, 体重是64kg,这个人与女生数据点的“距离”的平方 d^2 = ( 1.70 - 1.81 )^2 + ( 60 - 64 )^2 = 0.0121 + 16.0 = 16.0121,而与男生数据点的“距离”的平方d^2 = ( 1.80 - 1.81 )^2 + ( 70 - 64 )^2 = 0.0001 + 36.0 = 36.0001 。可见,在这种情况下,身高差的平方相对于体重差的平方基本可以忽略不计,但是身高对于辨别性别来说是十分重要的。为了解决这个问题,就需要将数据标准化(normalize),把每一个特征值除以该特征的范围,保证标准化后每一个特征值都在0~1之间。

四、例子及python实现:

          见    https://www.cnblogs.com/erbaodabao0611/p/7588840.html

下面两例代码参考

                 https://www.cnblogs.com/buzhizhitong/p/6036417.html,

                 https://blog.csdn.net/zouxy09/article/details/16955347

1.已知四个点的特征向量及其类别如下表所示,问 [1.2, 1.0] 和 [0.1, 0.3]两个样本点属于哪一类?

     

样本点 类别
[1.0, 0.9] A
[1.0, 1.0] A
[0.1, 0.2] B
[0.0, 0.1]

代码:

# -*- coding: utf-8 -*-
"""
Created on Sun Nov  6 16:09:00 2016

@author: Administrator
"""

# Input:
#   newInput:待测的数据点(1xM)
#   dataSet:已知的数据(NxM)
#   labels:已知数据的标签(1xM)
#   k:选取的最邻近数据点的个数
#
# Output:
#   待测数据点的分类标签
#

from numpy import *


# creat a dataset which contain 4 samples with 2 class
def createDataSet():
    # creat a matrix: each row as a sample
    group = array([[1.0, 0.9], [1.0, 1.0], [0.1, 0.2], [0.0, 0.1]])
    labels = ['A', 'A', 'B', 'B']
    return group, labels


# classify using KNN
def KNNClassify(newInput, dataSet, labels, k):
    numSamples = dataSet.shape[0]  # row number
    # step1:calculate Euclidean distance
    # tile(A, reps):Constract an array by repeating A reps times
    diff = tile(newInput, (numSamples, 1)) - dataSet
    squreDiff = diff ** 2
    squreDist = sum(squreDiff, axis=1)  # sum if performed by row
    distance = squreDist ** 0.5

    # step2:sort the distance
    # argsort() returns the indices that would sort an array in a ascending order
    sortedDistIndices = argsort(distance)

    classCount = {}
    for i in range(k):
        # choose the min k distance
        voteLabel = labels[sortedDistIndices[i]]

        # step4:count the times labels occur
        # when the key voteLabel is not in dictionary classCount,
        # get() will return 0
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1
    # step5:the max vote class will return
    maxCount = 0
    for k, v in classCount.items():
        if v > maxCount:
            maxCount = v
            maxIndex = k

    return maxIndex


# test

dataSet, labels = createDataSet()

testX = array([1.2, 1.0])
k = 3
outputLabel = KNNClassify(testX, dataSet, labels, 3)

print("Your input is:", testX, "and classified to class: ", outputLabel)

testX = array([0.1, 0.3])
k = 3
outputLabel = KNNClassify(testX, dataSet, labels, 3)

print("Your input is:", testX, "and classified to class: ", outputLabel)

 2.手写数字识别例子,代码及数据已调试通过,见https://download.csdn.net/download/albert201605/10570158

参考:

1.李航,《统计学习方法》第3章

2.郑捷,《机器学习算法原理与编程实践》

3.https://blog.csdn.net/zouxy09/article/details/16955347

4.https://www.cnblogs.com/buzhizhitong/p/6036417.html

5.https://www.cnblogs.com/erbaodabao0611/p/7588840.html

6.https://www.cnblogs.com/magic-girl/p/python-kNN.html

猜你喜欢

转载自blog.csdn.net/Albert201605/article/details/81265004