K近邻实战手写数字识别

1、导包

import numpy as np
import operator
from os import listdir
from sklearn.neighbors import KNeighborsClassifier as KNN

%config ZMQInteractiveShell.ast_node_interactivity='all'

2、定义将图像转换成向量的函数

"""
函数说明:将32x32的二进制图像转换成1x1024向量

Parameters:
    filename - 文件名
Returns:
    returnVect - 返回的二进制图像的1x1024向量
"""
def img2vector(filename):
    # 创建1x1024零向量
    returnVect = np.zeros((1, 1024))
    # 打开文件
    fr = open(filename)
    # 按行读取
    for i in range(32):
        # 读一行数据
        lineStr = fr.readline()
        # 每一行的前32个元素一次添加到returnVect中
        for j in range(32):
            returnVect[0, 32*i + j] = int(lineStr[j])
    # 返回转换后的1x1024向量
    return returnVect

3、定义手写数字识别系统函数

"""
函数说明:手写数字分类测试

Parameters:
    无
Returns:
    无
"""
def handwritingClassTest():
    # 训练集的Labels
    hwLabels = []
    # 返回trainingDigits目录下的文件名
    trainingFileList = listdir('trainingDigits')
    # 返回文件夹下的文件的个数
    m = len(trainingFileList)
    # 初始化训练的Mat矩阵,训练集
    trainingMat = np.zeros((m, 1024))
    # 从文件集中解析出训练集的类别
    for i in range(m):
        # 获得文件的名字
        fileNameStr = trainingFileList[i]
        # 获得分类的数字
        classNumber = int(fileNameStr.split('_')[0])
        # 将获得的类别添加到hwLabels中
        hwLabels.append(classNumber)
        # 将每一个文件的1x1024数据存储到trainingMat矩阵中
        trainingMat[i, :] = img2vector('trainingDigits/%s' % (fileNameStr))
    # 构建KNN分类器
    neigh = KNN(n_neighbors=3, algorithm='auto')
    # 拟合模型,trainingMat为训练矩阵,hwLabels为对应的标签
    neigh.fit(trainingMat, hwLabels)
    # 返回testDigits目录下的文件列表
    testFileList = listdir('testDigits')
    # 错误检查计数
    errorCount = 0.0
    # 测试数据的数量
    mTest = len(testFileList)
    # 从文件中解析出测试集的类别并进行分类测试
    for i in range(mTest):
        # 获得文件的名字
        fileNameStr = testFileList[i]
        # 获得分类的数字
        classNumber = int(fileNameStr.split('_')[0])
        # 获得测试集的1x1024向量,用于训练
        vectorUnderTest = img2vector('testDigits/%s' % (fileNameStr))
        # 获得预测结果
        classifierResult = neigh.predict(vectorUnderTest)
        # 打印
        print('分类返回结果为%d\t真实结果为%d' % (classifierResult, classNumber))
        if(classifierResult != classNumber):
            errorCount += 1.0
    print('总共错了%d个数据\n错误率为%f%%' %(errorCount, errorCount/mTest * 100))

4 运行结果

if __name__ == "__main__":
    handwritingClassTest()

数据集地址:

链接:https://pan.baidu.com/s/1-F2LyVh63i4yjIwweTYjNg
提取码:3gsa

参考:

1、《机器学习实战》书籍

2、https://github.com/apachecn/AiLearning

3、https://cuijiahua.com/blog/2017/11/ml_1_knn.html

4、深度之眼机器学习实战训练营课后作业(http://www.deepshare.net/

猜你喜欢

转载自www.cnblogs.com/WJZheng/p/11258147.html