李航《统计学习方法》K近邻代码

李航《统计学习方法》K近邻代码

import numpy as np
import pandas as pd
import time

def loadData(filename):
    """Load a label-first CSV file into plain Python lists.

    Every non-blank line is parsed as ``label,v0,v1,...`` where all fields
    are integers. The values after the label are divided by 255; the label
    is kept as an int.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A ``(dataArr, labelArr)`` tuple. ``dataArr`` is a list with one list
        of floats per sample; ``labelArr`` is the list of integer labels in
        the same order.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a field on a non-blank line is not an integer.
    """
    print('start to load data')
    dataArr = []
    labelArr = []
    # The context manager guarantees the handle is closed, even when a
    # malformed line makes int() raise half-way through the file.
    with open(filename, 'r') as fr:
        # Iterate the handle lazily instead of readlines() so the raw text
        # of the whole file is never held in memory at once.
        for line in fr:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline at EOF) carry no
                # sample; skipping them avoids int('') blowing up.
                continue
            curLine = line.split(',')
            dataArr.append([int(num) / 255 for num in curLine[1:]])
            labelArr.append(int(curLine[0]))

    return dataArr, labelArr

def calcDist(x1,x2):
    """Return the Euclidean (p=2) distance between ``x1`` and ``x2``.

    Both arguments must support element-wise subtraction (e.g. NumPy arrays
    or matrices of matching shape). All squared differences are summed into
    a single scalar before the square root is taken.
    """
    # Element-wise difference of the two vectors.
    delta = x1 - x2
    squared = np.square(delta)
    return np.sqrt(np.sum(squared))


def getClosest(trainDataMat, trainLabelMat, x, topK):
    """Predict the label of ``x`` by majority vote of its nearest neighbours.

    The Euclidean distance from ``x`` to every training sample is computed,
    the ``topK`` closest samples are selected, and the label that occurs most
    often among them is returned. When several labels share the highest
    count, the smallest label wins.

    Args:
        trainDataMat: Training samples, one per row. Anything
            ``np.asarray`` accepts (nested lists, ndarray, ``np.matrix``).
        trainLabelMat: One non-negative integer label per training sample,
            as a flat sequence or a row/column vector.
        x: The sample to classify; flattened to a vector before use.
        topK: Number of nearest neighbours that vote. Values larger than the
            training set size simply use every training sample.

    Returns:
        The winning label as a Python ``int``.

    Raises:
        ValueError: If the training set is empty or ``topK`` is not positive
            (there would be no neighbour to vote), or if a voting label is
            negative (rejected by ``np.bincount``).
    """
    trainData = np.asarray(trainDataMat)
    trainLabels = np.asarray(trainLabelMat).ravel()
    query = np.asarray(x).ravel()

    if len(trainData) == 0 or topK <= 0:
        raise ValueError(
            'getClosest needs at least one training sample and topK >= 1')

    # One row per sample, whatever the per-sample shape was.
    trainData = trainData.reshape(len(trainData), -1)

    # Broadcasting subtracts the query from every row at once, replacing a
    # Python-level loop over the whole training set.
    diff = trainData - query
    distList = np.sqrt(np.sum(np.square(diff), axis=1))

    # Indices of the topK smallest distances; a stable sort keeps the
    # original sample order among equal distances, so results are
    # reproducible.
    topKList = np.argsort(distList, kind='stable')[:topK]

    # bincount sizes the vote table from the labels actually seen, so the
    # number of classes is no longer fixed at 10. argmax returns the first
    # maximum, i.e. the smallest label on a tie.
    votes = np.bincount(trainLabels[topKList].astype(np.intp))
    return int(np.argmax(votes))


def model_test(trainDataArr,trainLabelArr,testDataArr,testLabelArr,topK,maxSamples=200):
    """Measure k-NN accuracy on the leading samples of a test set.

    Each evaluated test sample is classified with ``getClosest`` against the
    full training set and compared with its true label. Progress is printed
    for every sample.

    Args:
        trainDataArr: Training samples, one per row.
        trainLabelArr: Labels of the training samples.
        testDataArr: Test samples, one per row.
        testLabelArr: Labels of the test samples.
        topK: Number of neighbours passed on to ``getClosest``.
        maxSamples: Upper bound on how many test samples are evaluated,
            counted from the start of the test set. Defaults to 200, the
            value that used to be hard-coded; ``None`` evaluates the whole
            test set.

    Returns:
        The fraction of evaluated samples that were classified correctly.

    Raises:
        ValueError: If there is no test sample to evaluate.
    """
    print('start to test')
    # Plain ndarrays instead of np.mat: the np.mat alias was removed in
    # NumPy 2.0, and 1-D rows / scalar labels compare without matrix quirks.
    trainData = np.asarray(trainDataArr)
    trainLabels = np.asarray(trainLabelArr).ravel()
    testData = np.asarray(testDataArr)
    testLabels = np.asarray(testLabelArr).ravel()

    # Never index past the end of a test set smaller than the cap.
    if maxSamples is None:
        testCnt = len(testData)
    else:
        testCnt = min(maxSamples, len(testData))
    if testCnt <= 0:
        raise ValueError('model_test needs at least one test sample')

    errorCnt = 0

    for i in range(testCnt):
        print('test %d:%d' %(i+1,testCnt))
        y = getClosest(trainData,trainLabels,testData[i],topK)
        if y != testLabels[i]:
            errorCnt += 1

    return 1 - (errorCnt / testCnt)


if __name__ == "__main__":
    # Timestamp taken before any data is read, so the reported run time
    # covers loading as well as evaluation.
    tic = time.time()
    # Each call returns a (data, labels) pair; training set first.
    trainSet = loadData('./mnist_train.csv')
    testSet = loadData('./mnist_test.csv')
    # Accuracy on the test data with 25 voting neighbours.
    accuracy = model_test(*trainSet, *testSet, 25)
    toc = time.time()
    # Report the accuracy, then the elapsed wall-clock seconds.
    print('the accuracy rate is:',accuracy)
    print('the run time is:',toc - tic)
    

猜你喜欢

转载自blog.csdn.net/m0_45388819/article/details/113756475