所谓k-近邻算法,即是给定一个训练数据集,对新的输入实例,在训练数据集中找到与该实例最邻近的k个实例(也就是上面所说的k个邻居), 这k个实例的多数属于某个类,就把该输入实例分类到这个类中。
k-近邻算法中,最重要的是距离计算,这些公式在初、高中就已经学习过了。但是,那时候学习的是基本的三维空间。我们知道,三维空间中
两点的距离计算公式,
这里拓展到多维度空间。分类问题,一般输入为n维度的特征(比如,图像的话一般为
矩阵,需要先转换为
维度)。那么,对于
维度特征两点
距离的计算公式,
- inX是一个 的特征向量(inX是未知点,需要被分类),相当于dataSet中的一行。比如,
- dataSet是一个 带标签(labels)的数据集,每行表示一条数据,总共m条数据,比如,
k-近邻算法(kNN.py)
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0)+1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
数据集
def createDataSet():
dataSet = np.array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return dataSet, labels
运行k-近邻算法
if __name__=='__main__':
dataSet, labels = createDataSet()
ret = classify0([0.3,0.5],dataSet, labels , 3)
print(ret)
那么,上面的k近邻算法classify0是如何计算的呢?首先,计算inX与每个已知点的距离,这时候需要先将inX拉伸,然后与dataSet做差运算。
dataSetSize = dataSet.shape[0]
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
以上,相当于完成了距离计算公式中的减法 操作(注意,这里有多个特征),同时这里打算一次性,计算被预测点与所有已知点的距离。(dataSet每行是一个已知点的特征,即 )
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
接着,将所得的差进行平方。然后,再求和,此时得到一个列向量 。最后,进行开根号计算,得出被预测点与所有已知点的距离(组成一个向量 )。
后续的工作是选择距离最小的k个点了,那么是如何完成的呢?先看几个函数的作用。
1)argsort函数,很显然返回的是由小至大的排序索引(比如,以下最小的为 ,其次为 )。
In [1]:a = np.array([1.5,5.3,3.7,2.0,6.8])
In [2]:a.argsort()
Out[3]: array([0, 3, 2, 1, 4], dtype=int64)
2)classCount.get(voteIlabel, 0),字典
的get方法,表示如果key不存在,则value用0来代替。
3)operator.itemgetter(1),排序函数这里指定以字典的第二项,即value大小作为排序依据。
References:
[1] 李锐, 李鹏, 曲亚东, 王斌[译]. 机器学习实战[M]. 北京:人民邮电出版社, 2013.
[2] k近邻算法,百度百科,https://baike.baidu.com/item/k%E8%BF%91%E9%82%BB%E7%AE%97%E6%B3%95,2018-6-27
© qingdujun
2018-6-27 于 北京 怀柔
附录(0-9手写数字识别):
数据集地址:https://pan.baidu.com/s/1ANXq1AR84Y2e3A80T-39Fg 密码:ztik
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 27 17:33:51 2018
@author: qingdujun
"""
import numpy as np
import operator,os
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = np.tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0)+1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
def img2vector(filename):
returnVect = np.zeros((1,1024)) # 32*32 = 1024
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect
def handWritingClassTest():
hwLabels = []
trainingFileList = os.listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = np.zeros((m,1024))
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:]=img2vector('trainingDigits/%s'%fileNameStr)
testFileList = os.listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s'%fileNameStr)
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
print("the classifier came back with:%d, the real answer is:%d"%(classifierResult, classNumStr))
if (classifierResult != classNumStr):
errorCount += 1.0
print("\nthe total number of error is:%d"%errorCount)
print("\nthe total error rate is:%f"%(errorCount/float(mTest)))
if __name__=='__main__':
handWritingClassTest()