代码理解-机器学习实战k近邻算法(kNN)笔记(Python3)

版权声明:欢迎去我的新家https://www.jianshu.com/u/906a78709f1d https://blog.csdn.net/dongyanwen6036/article/details/81712484

今天学习了《机器学习实战》这本书介绍的第一个机器学习算法—k近邻算法。争取理解代码每一步。

  • code and analysis
'''
Created on Sep 16, 2010
kNN: k Nearest Neighbors

Input:      inX: vector to compare to existing dataset (1xN)
            dataSet: size m data set of known vectors (NxM)
            labels: data set labels (1xM vector)
            k: number of neighbors to use for comparison (should be an odd number)

Output:     the most popular class label

@author: pbharrin
'''
from numpy import *
import operator
from os import listdir

def classify0(inX, dataSet, labels, k):
        """
        inX 是输入的测试样本,当前的点,是一个[x, y] 大小1*N样式的
        dataset 是训练样本集,大小N*M样式的
        labels 是训练样本标签,大小1*M样式的
        k 是top k最相近的.

        tile属于numpy模块下边的函数
        tile(A, reps)返回一个shape=reps的矩阵,矩阵的每个元素是A
        比如 A=[0,1,2] 那么,tile(A, 2)= [0, 1, 2, 0, 1, 2]
        tile(A,(2,2)) = [[0, 1, 2, 0, 1, 2],
                    [0, 1, 2, 0, 1, 2]]

        """
    dataSetSize = dataSet.shape[0]
    diffMat = tile(inX, (dataSetSize,1)) - dataSet
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5
    sortedDistIndicies = distances.argsort()     
    classCount={}          
    for i in range(k):
        '''
        range(k)等价range(0,k)
        index = sortedDistIndicies[i]是第i个最相近的样本下标
        voteIlabel = labels[index]是样本index对应的分类结果('A' or 'B')
        classCount.get(voteIlabel, 0)返回voteIlabel的值,如果不存在,则返回0
        然后将票数增1
        把分类结果进行排序,然后返回得票数最多的分类结果
        '''
        voteIlabel = labels[sortedDistIndicies[i]]
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        """
        sorted(iterable[, cmp[, key[, reverse]]])
        iterable -- 可迭代对象。
        cmp -- 比较的函数,这个具有两个参数,参数的值都是从可迭代对象中取出,此函数必须遵守的规则为,大于则返回1,小于则返回-1,等于则返回0。
        key -- 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。
        reverse -- 排序规则,reverse = True 降序 , reverse = False 升序(默认)。

        dict.get(key, default=None)
        key -- 字典中要查找的键。
        default -- 如果指定键的值不存在时,返回该默认值值。
        返回指定键的值,如果值不在字典中返回默认值 None。
        """
    return sortedClassCount[0][0]

def createDataSet():
    """
    函数作用:构建一组训练数据(训练样本),共4个样本
    同时给出了这4个样本的标签,及labels
    """
    group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
    labels = ['A','A','B','B']
    return group, labels

def file2matrix(filename):
    """
    从文件中读入训练数据,并存储为矩阵

    str.strip([chars]);
    chars -- 移除字符串头尾指定的字符序列。
    返回移除字符串头尾指定的字符序列生成的新字符串。
    此函数中: line.strip()去除首尾空格 ,回车


    例如 str = "Line1-abcdef \nLine2-abc \nLine4-abcd";
    line.split('\t')按tab键分割字符,之后得到
    ['Line1-abcdef', 'Line2-abc', 'Line4-abcd']


    """
    fr = open(filename)
    numberOfLines = len(fr.readlines())         #get the number of lines in the file
    returnMat = zeros((numberOfLines,3))        #prepare matrix to return
    classLabelVector = []                       #prepare labels return   

    index = 0
    for line in fr.readlines():
        line = line.strip()
        listFromLine = line.split('\t')
        returnMat[index,:] = listFromLine[0:3]  # 表示0,1,2 column 见下面举例图
        classLabelVector.append(int(listFromLine[-1]))
        index += 1
    return returnMat,classLabelVector


def autoNorm(dataSet):
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    normDataSet = zeros(shape(dataSet))
    m = dataSet.shape[0]
    normDataSet = dataSet - tile(minVals, (m,1))
    normDataSet = normDataSet/tile(ranges, (m,1))   #element wise divide
    return normDataSet, ranges, minVals

def datingClassTest():
    hoRatio = 0.50      #hold out 10%
    datingDataMat,datingLabels = file2matrix('datingTestSet2.txt')       #load data setfrom file
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]
    numTestVecs = int(m*hoRatio)
    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
        print("the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]))
        if (classifierResult != datingLabels[i]): errorCount += 1.0
    print("the total error rate is: %f" % (errorCount/float(numTestVecs)))
    print (errorCount)

def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect


    '''
    os.listdir(path),path -- 需要列出的目录路径,返回指定路径下的文件和文件夹列表。
    '''
def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)                              #number 
    trainingMat = zeros((m,1024))
    for i in range(m):                                      #遍历
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)#把 列向量合并成矩阵
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print ("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        if (classifierResult != classNumStr): errorCount += 1.0
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is: %f" % (errorCount/float(mTest)))
  • Simple classification of kNN algorithm in python
    切换到kNN.py目录下import kNN
    这里写图片描述

note
Python3.5中:iteritems变为items

  • Parsing data from a text file
    先看一下数据:
    这里写图片描述
    运行代码:
>>> import kNN
>>> datingDataMat,datingLabels=kNN.file2matrix('datingTestSet2.txt')

这里写图片描述
# 表示0,1,2 column 见下面举例图note
这里写图片描述

  • Hand writing kNN Test
import kNN
kNN.handwritingClassTest()

这里写图片描述



参考

猜你喜欢

转载自blog.csdn.net/dongyanwen6036/article/details/81712484
今日推荐