# 机器学习实战笔记之knn — study notes for "Machine Learning in Action": k-nearest neighbors (kNN)

import operator
from collections import Counter
from os import listdir

import matplotlib
import matplotlib.pyplot as plt
from numpy import *


def createDataSet():
    """Build a tiny toy training set for trying out the classifier.

    Returns:
        A ``(group, labels)`` pair: ``group`` is a 4x2 float array of points
        and ``labels`` lists the class of each row (two 'A', then two 'B').
    """
    points = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    group = array(points)
    labels = list('AABB')
    return group, labels


def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by majority vote of its ``k`` nearest training samples.

    Distances are Euclidean.

    Args:
        inX: the sample to classify. Either a 1-D sequence/array with one
            entry per column of ``dataSet``, or a ``(1, n)`` row.
        dataSet: 2-D array of training samples, one per row.
        labels: class labels; ``labels[i]`` is the class of ``dataSet[i]``.
        k: number of nearest neighbours that vote. Values larger than the
            number of training rows are clamped, so every row votes.

    Returns:
        The most common label among the nearest neighbours. On a tie, the
        label that first received a vote (i.e. from the nearer neighbour)
        wins.

    Raises:
        ValueError: if ``k`` is less than 1 or ``dataSet`` has no rows.
    """
    if k < 1:
        raise ValueError("k must be at least 1, got %r" % (k,))
    dataSetSize = dataSet.shape[0]
    if dataSetSize == 0:
        raise ValueError("dataSet must contain at least one training row")

    # Broadcasting subtracts inX from every row of dataSet, so there is no
    # need to build a tiled copy first.
    diffMat = asarray(inX) - dataSet
    distances = ((diffMat ** 2).sum(axis=1)) ** 0.5

    # Training-row indices ordered from nearest to farthest. Slicing with
    # [:k] naturally clamps k to the number of available rows.
    nearest = distances.argsort()[:k]
    classCount = Counter(labels[i] for i in nearest)
    # most_common keeps first-encountered order among equal counts, so ties
    # go to the label voted for by the nearer neighbour.
    return classCount.most_common(1)[0][0]
        
def file2matrix(filename):
    """Parse a tab-separated data file into a feature matrix and a label list.

    Each non-blank line must hold at least three tab-separated numeric fields.
    The first three fields become one row of the feature matrix, and the last
    field is parsed as an integer class label. Blank lines (for example a
    trailing newline at the end of the file) are skipped.

    Args:
        filename: path of the text file to read.

    Returns:
        ``(returnMat, classLabelVector)``: an ``(N, 3)`` float array and a list
        of ``N`` ints, where ``N`` is the number of non-blank lines.
    """
    # The with-block guarantees the file is closed, even if parsing fails.
    with open(filename) as fr:
        stripped = (line.strip() for line in fr)
        rows = [text.split('\t') for text in stripped if text]

    returnMat = zeros((len(rows), 3))
    classLabelVector = []
    for index, listFromLine in enumerate(rows):
        returnMat[index, :] = [float(field) for field in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector


def autoNorm(dataSet):
    """Min-max normalize every column of ``dataSet`` into the range [0, 1].

    Each value is rescaled as ``(value - colMin) / (colMax - colMin)``.

    Args:
        dataSet: 2-D numeric array, one sample per row.

    Returns:
        ``(normDataSet, ranges, minVals)``:
        - ``minVals`` holds the per-column minima.
        - ``ranges`` holds the per-column divisor actually used:
          ``colMax - colMin``, or 1 for a constant column.
        New samples are normalized consistently with
        ``(x - minVals) / ranges``.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranges = maxVals - minVals
    # A constant column has zero range; dividing by it would yield NaN/inf.
    # Divide by 1 instead, so such a column normalizes to all zeros.
    ranges = where(ranges == 0, 1, ranges)
    # Broadcasting applies the per-column minimum and range to every row.
    normDataSet = (dataSet - minVals) / ranges
    return normDataSet, ranges, minVals
    
def datingClassTest(hoRatio=0.10, filename='datingTestSet2.txt', k=3):
    """Estimate the kNN error rate on the dating data with a hold-out split.

    The data are normalized with ``autoNorm``. The first ``hoRatio`` fraction
    of rows is classified against the remaining rows, and each prediction is
    printed.

    Args:
        hoRatio: fraction of rows held out as the test set (default 10%).
        filename: tab-separated data file read by ``file2matrix``.
        k: number of neighbours passed to ``classify0``.

    Returns:
        The error rate (misclassified test rows / test rows) as a float.

    Raises:
        ValueError: if ``hoRatio`` selects no test rows.
    """
    datingDataMat, datingLabels = file2matrix(filename)
    normMat, ranges, minVals = autoNorm(datingDataMat)
    m = normMat.shape[0]

    # Rows [0, numTestVecs) form the test set; the rest form the training set.
    numTestVecs = int(m * hoRatio)
    if numTestVecs == 0:
        raise ValueError("hoRatio %r selects no test rows out of %d" % (hoRatio, m))

    # Slice the training data once, rather than re-copying it per test row.
    trainMat = normMat[numTestVecs:m, :]
    trainLabels = datingLabels[numTestVecs:m]

    errorCount = 0.0
    for i in range(numTestVecs):
        classifierResult = classify0(normMat[i, :], trainMat, trainLabels, k)
        print("the classifier came back with %d, the real answer is %d"
              % (classifierResult, datingLabels[i]))
        if classifierResult != datingLabels[i]:
            errorCount += 1.0

    errorRate = errorCount / float(numTestVecs)
    print("the total error rate is %f" % errorRate)
    return errorRate
           
def classifyPerson():
    """Interactively predict whether someone is a good dating match.

    Asks three questions on standard input and normalizes the answers with
    the statistics of ``datingTestSet2.txt``. Then prints the kNN (k=3)
    verdict.
    """
    verdicts = ('not at all', 'in small doses', 'in large doses')

    # The questions are asked in this order...
    gamePercent = float(input("percentage of time spent playing video games?"))
    flierMiles = float(input("frequent flier miles earned per year?"))
    iceCreamLiters = float(input("liters of ice cream consumed per year?"))

    trainMat, trainLabels = file2matrix('datingTestSet2.txt')
    normMat, ranges, minVals = autoNorm(trainMat)

    # ...but the vector is assembled as (miles, games, ice cream). This
    # presumably matches the data file's column order — confirm.
    rawQuery = array([flierMiles, gamePercent, iceCreamLiters])
    query = (rawQuery - minVals) / ranges
    prediction = classify0(query, normMat, trainLabels, 3)

    # A class label selects its verdict via index label - 1.
    print("You will probably like this person:", verdicts[prediction - 1])


def img2vector(filename):
    """Convert a 32x32 text image of digit characters into a 1x1024 row vector.

    The first 32 characters of each of the first 32 lines are read as ints.
    Pixel (row ``i``, column ``j``) is stored at column ``32 * i + j`` of the
    result.

    Args:
        filename: path of the text image.

    Returns:
        A float array of shape ``(1, 1024)``.
    """
    returnVect = zeros((1, 1024))
    # The with-block closes the file handle, which the old version leaked.
    with open(filename) as fr:
        for i in range(32):
            lineStr = fr.readline()
            for j in range(32):
                returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect
        
def _digitLabelFromFilename(fileNameStr):
    """Return the class encoded before the first '_' of the file stem ('7_45.txt' -> 7)."""
    fileStr = fileNameStr.split('.')[0]
    return int(fileStr.split('_')[0])


def handwritingClassTest(trainingDir='trainingDigits', testDir='testDigits', k=3):
    """Evaluate kNN on the handwritten-digit images and print the error rate.

    Every file in ``trainingDir`` becomes one training row via ``img2vector``.
    Its label is the number before the first '_' of the file name. Every file
    in ``testDir`` is then classified against that training set.

    Args:
        trainingDir: directory of training images.
        testDir: directory of test images.
        k: number of neighbours passed to ``classify0``.

    Returns:
        The error rate (misclassified test files / test files) as a float.

    Raises:
        ValueError: if ``testDir`` contains no files.
    """
    trainingFileList = listdir(trainingDir)
    m = len(trainingFileList)
    trainingMat = zeros((m, 1024))
    hwLabels = []
    for i, fileNameStr in enumerate(trainingFileList):
        hwLabels.append(_digitLabelFromFilename(fileNameStr))
        trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, fileNameStr))

    testFileList = listdir(testDir)
    mTest = len(testFileList)
    if mTest == 0:
        raise ValueError("no test files found in %r" % (testDir,))

    errorCount = 0.0
    for fileNameStr in testFileList:
        classNumStr = _digitLabelFromFilename(fileNameStr)
        vectorUnderTest = img2vector('%s/%s' % (testDir, fileNameStr))
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, k)
        if classifierResult != classNumStr:
            errorCount += 1.0

    errorRate = errorCount / float(mTest)
    print("\nthe total number of errors is: %d" % errorCount)
    print("\nthe total error rate is %f" % errorRate)
    return errorRate
        

if __name__ == "__main__":
    # Run the handwritten-digit benchmark only when executed as a script,
    # not when this module is imported. The other experiments from the notes
    # are kept below as commented-out alternatives.
    #
    # Toy 4-point example:
    #   group, labels = createDataSet()
    #   print(group)
    #   print(labels)
    #   print(classify0([0, 0], group, labels, 3))
    #
    # Dating data: parse, normalize, evaluate, interactive query:
    #   datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
    #   print("datingDataMat", datingDataMat)
    #   normDataSet, ranges, minVals = autoNorm(datingDataMat)
    #   datingClassTest()
    #   classifyPerson()
    #
    # Scatter plot of dating feature columns 0 and 1; the label scales both
    # marker size and color:
    #   fig = plt.figure()
    #   ax = fig.add_subplot(111)
    #   ax.scatter(datingDataMat[:, 0], datingDataMat[:, 1],
    #              15.0 * array(datingLabels), 15.0 * array(datingLabels))
    #   plt.show()
    #
    # Inspect the first two 32-pixel image rows of one digit as a flat vector:
    #   testVector = img2vector('testDigits/0_13.txt')
    #   print(testVector[0, 0:32])
    #   print(testVector[0, 32:64])
    handwritingClassTest()

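# 猜你喜欢 (web-page footer: "You may also like")
#
# 转载自 blog.csdn.net/feidao84/article/details/80686772
# (Reposted from blog.csdn.net/feidao84/article/details/80686772)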