k近邻算法python实现 -- 《机器学习实战》

版权声明:转载请先联系并注明出处! https://blog.csdn.net/u010552731/article/details/89346576

k近邻算法python实现 -- 《机器学习实战》

 
     
  1 '''
  2 Created on Nov 06, 2017
  3 kNN: k Nearest Neighbors
  4 
  5 Input:      inX: vector to compare to existing dataset (1xN)
  6             dataSet: size m data set of known vectors (NxM)
  7             labels: data set labels (1xM vector)
  8             k: number of neighbors to use for comparison (should be an odd number)
  9 
 10 Output:     the most popular class label
 11 
 12 @author: Liu Chuanfeng
 13 '''
 14 import operator
 15 import numpy as np
 16 import matplotlib.pyplot as plt
 17 from os import listdir
 18 
 19 def classify0(inX, dataSet, labels, k):
 20     dataSetSize = dataSet.shape[0]
 21     diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
 22     sqDiffMat = diffMat ** 2
 23     sqDistances = sqDiffMat.sum(axis=1)
 24     distances = sqDistances ** 0.5
 25     sortedDistIndicies = distances.argsort()
 26     classCount = {}
 27     for i in range(k):
 28         voteIlabel = labels[sortedDistIndicies[i]]
 29         classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
 30     sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
 31     return sortedClassCount[0][0]
 32 
 33 #数据预处理,将文件中数据转换为矩阵类型
 34 def file2matrix(filename):
 35     fr = open(filename)
 36     arrayLines = fr.readlines()
 37     numberOfLines = len(arrayLines)
 38     returnMat = np.zeros((numberOfLines, 3))
 39     classLabelVector = []
 40     index = 0
 41     for line in arrayLines:
 42         line = line.strip()
 43         listFromLine = line.split('\t')
 44         returnMat[index,:] = listFromLine[0:3]
 45         classLabelVector.append(int(listFromLine[-1]))
 46         index += 1
 47     return returnMat, classLabelVector
 48 
 49 #数据归一化处理:由于矩阵各列数据取值范围的巨大差异导致各列对计算结果的影响大小不一,需要归一化以保证相同的影响权重
 50 def autoNorm(dataSet):
 51     maxVals = dataSet.max(0)
 52     minVals = dataSet.min(0)
 53     ranges = maxVals -  minVals
 54     m = dataSet.shape[0]
 55     normDataSet = (dataSet - np.tile(minVals, (m, 1))) / np.tile(ranges, (m, 1))
 56     return normDataSet, ranges, minVals
 57 
 58 #约会网站测试代码
 59 def datingClassTest():
 60     hoRatio = 0.10
 61     datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
 62     normMat, ranges, minVals = autoNorm(datingDataMat)
 63     m = normMat.shape[0]
 64     numTestVecs = int(m * hoRatio)
 65     errorCount = 0.0
 66     for i in range(numTestVecs):
 67         classifyResult = classify0(normMat[i,:], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
 68         print('theclassifier came back with: %d, the real answer is: %d' % (classifyResult, datingLabels[i]))
 69         if ( classifyResult != datingLabels[i]):
 70             errorCount += 1.0
 71         print ('the total error rate is: %.1f%%' % (errorCount/float(numTestVecs) * 100))
 72 
 73 #约会网站预测函数
 74 def classifyPerson():
 75     resultList = ['not at all', 'in small doses', 'in large doses']
 76     percentTats = float(input("percentage of time spent playing video games?"))
 77     ffMiles = float(input("frequent flier miles earned per year?"))
 78     iceCream = float(input("liters of ice cream consumed per year?"))
 79     datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
 80     normMat, ranges, minVals = autoNorm(datingDataMat)
 81     inArr = np.array([ffMiles, percentTats, iceCream])
 82     classifyResult = classify0((inArr-minVals)/ranges, normMat, datingLabels, 3)
 83     print ("You will probably like this persoon:", resultList[classifyResult - 1])
 84 
 85 
 86 #手写识别系统#============================================================================================================
 87 #数据预处理:输入图片为32*32的文本类型,将其形状转换为1*1024
 88 def img2vector(filename):
 89     returnVect = np.zeros((1, 1024))
 90     fr = open(filename)
 91     for i in range(32):
 92         lineStr = fr.readline()
 93         for j in range(32):
 94             returnVect[0, 32*i+j] = int(lineStr[j])
 95     return returnVect
 96 
 97 #手写数字识别系统测试代码
 98 def handwritingClassTest():
 99     hwLabels = []
100     trainingFileList = listdir('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\traingDigits')
101     m = len(trainingFileList)
102     trainingMat = np.zeros((m, 1024))
103     for i in range(m):                                     #|
104         fileNameStr = trainingFileList[i]                  #|
105         fileName = fileNameStr.split('.')[0]               #| 获取训练集路径下每一个文件,分割文件名,将第一个数字作为标签存储在hwLabels中
106         classNumber = int(fileName.split('_')[0])          #|
107         hwLabels.append(classNumber)                       #|
108         trainingMat[i,:] = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\traingDigits\\%s' % fileNameStr)    #变换矩阵形状: from 32*32 to 1*1024
109     testFileList = listdir('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits')
110     errorCount = 0.0
111     mTest = len(testFileList)
112     for i in range(mTest):              #同训练集
113         fileNameStr = testFileList[i]
114         fileName = fileNameStr.split('.')[0]
115         classNumber = int(fileName.split('_')[0])
116         vectorUnderTest = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits\\%s' % fileNameStr)
117         classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)   #计算欧氏距离并分类,返回计算结果
118         print ('The classifier came back with: %d, the real answer is: %d' % (classifyResult, classNumber))
119         if (classifyResult != classNumber):
120             errorCount += 1.0
121     print ('The total number of errors is: %d' % (errorCount))
122     print ('The total error rate is: %.1f%%' % (errorCount/float(mTest) * 100))
123 
# Simple unit test of func: file2matrix()
# datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
# print(datingDataMat)
# print(datingLabels)

# Usage of figure construction of matplotlib
# fig = plt.figure()
# ax = fig.add_subplot(111)
# ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
# plt.show()

# Simple unit test of func: autoNorm()
# normMat, ranges, minVals = autoNorm(datingDataMat)
# print(normMat)
# print(ranges)
# print(minVals)

# Simple unit test of func: img2vector
# testVect = img2vector(r'C:\Private\PycharmProjects\Algorithm\kNN\digits\testDigits\0_13.txt')
# print(testVect[0, 32:63])

# Run the demos only when executed as a script, so importing this module
# does not block on input() prompts or run long evaluations.
if __name__ == '__main__':
    # Dating site: hold-out test of the classifier
    datingClassTest()

    # Dating site: interactive prediction
    classifyPerson()

    # Handwritten digit recognition: evaluate on the test set
    handwritingClassTest()
 
     
Output:

theclassifier came back with: 3, the real answer is: 3
the total error rate is: 0.0%
theclassifier came back with: 2, the real answer is: 2
the total error rate is: 0.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 0.0%

...

theclassifier came back with: 2, the real answer is: 2
the total error rate is: 4.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 4.0%
theclassifier came back with: 3, the real answer is: 1
the total error rate is: 5.0%

percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice cream consumed per year?0.5
You will probably like this persoon: in small doses

...

The classifier came back with: 9, the real answer is: 9
The total number of errors is: 27
The total error rate is: 6.8%

 Reference:

《机器学习实战》

posted @ 2017-11-08 21:13 刘川枫 阅读( ...) 评论( ...) 编辑 收藏

猜你喜欢

转载自blog.csdn.net/u010552731/article/details/89346576