k-近邻(KNN)算法的应用

KNN手写数字识别

from numpy import *
from os import listdir
import operator

def classify0(inX, dataSet, labels, k):
    """Predict the label of ``inX`` by majority vote of its k nearest rows.

    Distances are Euclidean, measured from ``inX`` to every row of
    ``dataSet``.  When several labels receive the same number of votes,
    the label that was voted for first (i.e. by the neighbour that comes
    earliest in the distance ordering) wins.

    :param inX: the query sample; it is tiled once per row of ``dataSet``.
    :param dataSet: NumPy array of training samples, one sample per row.
    :param labels: indexable sequence; ``labels[r]`` is the label of row ``r``.
    :param k: number of nearest neighbours that vote.
    :return: the winning label.
    :raises IndexError: if ``k`` exceeds the number of rows, or ``k`` is 0.
    """
    rowCount = dataSet.shape[0]
    # Repeat the query once per training row so the subtraction is row-wise.
    delta = tile(inX, (rowCount, 1)) - dataSet
    distances = ((delta ** 2).sum(axis=1)) ** 0.5
    nearestFirst = distances.argsort()

    votes = {}
    for rank in range(k):
        label = labels[nearestFirst[rank]]
        votes[label] = votes.get(label, 0) + 1

    # sorted() is stable even with reverse=True, so equal vote counts keep
    # the order in which the labels were first seen.
    ranking = sorted(votes, key=votes.get, reverse=True)
    return ranking[0]
def img2vector(filename):
    """Read a 32x32 text-encoded digit image into a flat feature vector.

    Only the first 32 characters of each of the first 32 lines are used,
    and each of those characters must be a decimal digit.  The character
    at (line ``i``, column ``j``) is stored at index ``32*i + j``.

    :param filename: path of the text image file.
    :return: float ndarray of shape (1, 1024).
    :raises IndexError or ValueError: if the file has fewer than 32 lines,
        a line has fewer than 32 digits, or a read character is not a digit.
    """
    returnVc = zeros((1, 1024))
    # The context manager closes the file even when parsing raises.
    with open(filename) as fr:
        for i in range(32):
            # readline() returns '' past EOF, so a short file fails on
            # the indexing below instead of silently yielding zeros.
            lineStr = fr.readline()
            for j in range(32):
                returnVc[0, 32 * i + j] = int(lineStr[j])
    return returnVc

def _digitFromFileName(fileName):
    """Return the integer label encoded in a file name, e.g. '7_12.txt' -> 7."""
    return int(fileName.split('.')[0].split('_')[0])


def handWriteClassTest(trainingDir='trainingDigits', testDir='test', k=3):
    """Evaluate the k-NN digit classifier on a directory of test images.

    Every file in ``trainingDir`` and ``testDir`` is read with
    ``img2vector`` and must be named ``<digit>_<anything>.<ext>``; the
    leading integer is taken as the file's true label.  Each test image is
    classified with ``classify0`` against all training images, the
    prediction and answer are printed, and finally the error count and
    error rate are printed.

    :param trainingDir: directory holding the labelled training images.
    :param testDir: directory holding the images to classify.  Defaults to
        ``'test'``; an earlier version of this script used ``'testDigits'``.
    :param k: number of nearest neighbours that vote in ``classify0``.
    :return: the error rate (misclassified / number of test files).
    :raises ValueError: if either directory contains no files.
    """
    # Sorting makes the print order (and classify0's tie handling, which
    # depends on training-row order) independent of the filesystem.
    trainingFiles = sorted(listdir(trainingDir))
    if not trainingFiles:
        raise ValueError("no training files found in %r" % trainingDir)
    hwLabels = []
    trainingMat = zeros((len(trainingFiles), 1024))
    for i, fileNameStr in enumerate(trainingFiles):
        hwLabels.append(_digitFromFileName(fileNameStr))
        trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))

    testFiles = sorted(listdir(testDir))
    if not testFiles:
        # Without this guard the error-rate division below would fail.
        raise ValueError("no test files found in %r" % testDir)
    errorCount = 0
    for fileNameStr in testFiles:
        classNumStr = _digitFromFileName(fileNameStr)
        vectorUnderTest = img2vector('%s/%s' % (testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        print("返回结果为: %d, 答案为: %d" % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1
    errorRate = errorCount / float(len(testFiles))
    print("错误的个数: %d" % errorCount)
    print("错误率: %f" % errorRate)
    return errorRate


# NOTE(review): this runs the full evaluation (listing 'trainingDigits' and
# 'test' relative to the working directory) whenever the module is executed
# or imported; consider wrapping it in `if __name__ == "__main__":`.
handWriteClassTest()

猜你喜欢

转载自blog.csdn.net/qecode/article/details/78725519
今日推荐