# k-近邻算法实战2——识别手写数字

from numpy import *
from os import listdir
import operator

# k-nearest-neighbours classifier
def classify0(inX,group,label,k):
    """Classify ``inX`` by majority vote among its ``k`` nearest training rows.

    Parameters:
        inX:   the sample to classify; it is tiled to one copy per row of
               ``group`` before the difference is taken.
        group: array holding one training sample per row.
        label: sequence of labels; ``label[i]`` belongs to ``group[i]``.
        k:     number of nearest neighbours that get a vote.

    Returns:
        The label with the most votes. When several labels share the top
        count, the one met first while walking the neighbours from nearest
        to farthest is returned (the vote dict keeps insertion order and the
        sort below is stable).
    """
    rowCount = group.shape[0]

    # Euclidean distance from inX to every training row.
    offsets = tile(inX, (rowCount, 1)) - group
    distances = (offsets ** 2).sum(axis=1) ** 0.5
    nearestFirst = distances.argsort()

    # Tally the labels of the k closest rows.
    votes = {}
    for rank in range(k):
        neighbourLabel = label[nearestFirst[rank]]
        if neighbourLabel in votes:
            votes[neighbourLabel] += 1
        else:
            votes[neighbourLabel] = 1

    # Order the (label, count) pairs by count, highest first.
    ranking = list(votes.items())
    ranking.sort(key=lambda entry: entry[1], reverse=True)
    return ranking[0][0]

# Convert an image to a vector; the images are stored as text files
def img2vector(filename):
    """Read a 32x32 text-encoded image into a 1x1024 row vector.

    The first 32 lines of the file are read and the first 32 characters of
    each line are converted with ``int`` and stored row by row.

    Parameters:
        filename: path of the text file to read.

    Returns:
        A numpy array of shape (1, 1024) holding the converted characters.

    Raises:
        IndexError: if one of the first 32 lines has fewer than 32 characters.
        ValueError: if a character cannot be converted with ``int``.
    """
    returnVector = zeros((1,1024))
    # The context manager guarantees the file handle is closed, even when a
    # malformed line raises part-way through.
    with open(filename) as fr:
        for i in range(32):
            fileLine = fr.readline()
            for j in range(32):
                returnVector[0,32*i+j] = int(fileLine[j])
    return returnVector

def _digitLabel(fileName):
    """Return the part of a file name before the first '_' (extension dropped)."""
    return fileName.split('.')[0].split('_')[0]


def handwritingClassTest(trainingDir='D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/trainingDigits',
                         testDir='D:/BaiduNetdiskDownload/machinelearninginaction/Ch02/digits/testDigits',
                         k=3):
    """Run the kNN classifier over a directory of test digits and print the results.

    Every file in ``trainingDir`` is loaded with ``img2vector`` and labelled
    with the part of its name before the first '_'. Every file in ``testDir``
    is then classified with ``classify0`` and compared against the label taken
    from its own file name. One line is printed per test file, followed by the
    total error count and the error rate.

    Parameters:
        trainingDir: directory holding the training files.
        testDir:     directory holding the test files.
        k:           number of neighbours passed to ``classify0``.

    Returns:
        None; all results are printed.
    """
    # Build the training matrix and its label list.
    trainingFileList = listdir(trainingDir)
    m = len(trainingFileList)
    hwLabels = []
    hwVector = zeros((m,1024))
    for i, fr in enumerate(trainingFileList):
        hwLabels.append(_digitLabel(fr))
        hwVector[i,:] = img2vector('%s/%s' % (trainingDir, fr))

    # Start testing
    testFileList = listdir(testDir)
    errorCount = 0
    mTest = len(testFileList)
    for fr in testFileList:
        frNameLabel = _digitLabel(fr)
        inX = img2vector('%s/%s' % (testDir, fr))
        returnLabel = classify0(inX,hwVector,hwLabels,k)
        # The true label comes from the file name, the "test result" is the
        # classifier output; keep the arguments in that order.
        print('the real result is %s,the test result is %s' % (frNameLabel,returnLabel))
        if returnLabel != frNameLabel:
            errorCount += 1
    print('the total number of errors is %d ' % errorCount)
    # An empty test directory would otherwise divide by zero.
    errorRate = errorCount/float(mTest) if mTest else 0.0
    print('the total error rate is %f' % errorRate)

# Run the handwriting recognition test only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    handwritingClassTest()

# 猜你喜欢

# 转载自blog.csdn.net/lwycc2333/article/details/81558301