机器学习实战 knn算法

机器学习实战这本书

第二章KNN算法的代码及注释

knn函数

# -- coding: utf-8 --
#kNN

from numpy import*
import operator
from os import listdir

#数据标签
def createDataSet ():
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group, labels

#数据准备
def file2matrix(filename):
    fr = open(filename)              #打开txt文件
    arrayOLines = fr.readlines()      #读取文件的每一行 为列表存储
    numberOfLines = len(arrayOLines)  #计算列表的长度 多少行
    returnMat = zeros((numberOfLines,3)) #返回矩阵  0填充的数组 行为numberOfLines 列为3
    classLabelVector = []
    index = 0
    for line in arrayOLines:
        line = line.strip()         # 去除回车字符
        listFromline = line.split('\t')  # 通过指定分隔符对字符串进行切片
        returnMat[index,:] = listFromline[0:3]
        classLabelVector.append(int(listFromline[-1]))  # 传入的对象添加到现有列表中
        index += 1
    return returnMat,classLabelVector

#归一化函数
def autoNorm(dataSet):
    minVals = dataSet.min(0)    #表示从列中选取最小值
    maxvals = dataSet.max(0)
    ranges = maxvals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]         #查看矩阵或者数组的维数  行数
    normDataSet = dataSet - tile(minVals, (m,1))       #tile 构造矩阵 拓展成m*3矩阵  行复制M次 列复制1次
    normDataSet = normDataSet/tile(ranges,(m,1))       #对应位置相除  不是矩阵除法
    return normDataSet, ranges, minVals

#KNN算法
def classify0(inX, dataSet, labels, k):     #一次计算一个点(样本)
    dataSetSize = dataSet.shape[0]              #查看矩阵或者数组的维数 4
    diffMat = tile(inX, (dataSetSize,1)) - dataSet #tile共有2个参数,A指待输入数组,reps则决定A重复的次数。整个函数用于重复数组A来构建新的数组
    sqDiffmat = diffMat ** 2                    #两个乘号就是乘方
    sqDistances = sqDiffmat.sum(axis=1)      #numpy.sum(a)都能将列表a中的所有元素求和并返回,axis=1以后就是将一个矩阵的每一行向量相加
    distance = sqDistances ** 0.5
    sorteDistIndicies = distance.argsort()   #数组值从小到大  索引值的排序
    classCount = {}                          #字典
    for i in range(k):                       #rang代表从0到K(不包含K)  K为KNN中的K
        voteIlabel = labels[sorteDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #字典中要查找的键。如果指定键的值不存在时,返回该默认值0。
    sortedClassCount = sorted(classCount.iteritems(),       #对所有可迭代的对象进行排序操作
                              key = operator.itemgetter(1), reverse=True)   #对象的第二个域的值即数字,reverse = True 降序 , reverse = False 升序
    return sortedClassCount[0][0]

#测试数据
def datingClassTest():
    hoRatio = 0.10
    datingDateMat ,datingLables = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(datingDateMat)
    m = normMat.shape[0]   #行数 代表样本个数
    numTestVecs = int(m*hoRatio)    #取其中10%作为测试样本
    errorCount = 0.0
    for i in range(numTestVecs):       #rang代表从0到numTestVecs(不包含本身)
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],
                                     datingLables[numTestVecs:m], 3)
        print ("the classifier came back with: %d, the real answer is: %d"
                %(classifierResult, datingLables[i]))
        if (classifierResult != datingLables[i]): errorCount += 1.0
    print ("the total error rate is %f" %(errorCount/float(numTestVecs)))


#手写体识别 数据处理
def img2vector(filename):
    returnVect = []
    fr = open(filename)
    arrayOLines = fr.readlines()     #读取txt文件中的每一行
    for line in arrayOLines:         #遍历每一行 逐行处理
        line = line.strip()          #删去每行回车符
        lens = len(line)
        for i in range(lens):
            returnVect.append(int(line[i]))   #将每行字符压入向量中
    return returnVect

#手写体识别 分类
def handwritingClassTest():
    hwlabels = []
    trainingFileList = listdir('trainingDigits')    #获得文件夹里文件名的信息
    m = len(trainingFileList)                        #计算文件夹里文件个数
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]           #遍历每一个文件
        hwlabels.append(int(fileNameStr.split('_')[0]))  #记录文件的标签
        trainingMat[i,:] = img2vector('trainingDigits/%s' %fileNameStr)   #处理文件数据 放到一个矩阵里

    testFileList = listdir('testDigits')
    n = len(testFileList)
    errorcount = 0.0
    for i in range(n):
        fileNametext = testFileList[i]
        realResult = (int(fileNametext.split('_')[0]))
        textInx = img2vector('testDigits/%s' % fileNametext)
        textResult = classify0(textInx, trainingMat, hwlabels, 3)
        if(textResult != realResult):  errorcount += 1.0
        print ("the classifier came back with: %d, the real answer is: %d"
                %(textResult, realResult))
    print ("the total error rate is %f" % (errorcount / float(n)))


猜你喜欢

转载自blog.csdn.net/fm904813255/article/details/80268859