机器学习实战(4)—— kNN实战手写识别系统

我:终于到周末了,可以休息一下了!!!来几把LOL!!!

(叮铃…叮铃…叮铃…)

我:喂,老板啊?怎么啦

老板:小韩啊,在家休息吗?

我:是啊。

老板:别休息啦,来加个班,用上次你写的kNN,做一个手写识别系统,训练集和测试集我都发你邮箱了!周日晚上给我!

我:(What???大周末的,你让我加班,老子不干了!)行,保证写出来!

行了行了,周末不休息了,开工!

这次我们要构建一个手写识别系统,为了简单,我们就只识别0-9。需要识别的数字已经用图形处理软件,处理成具有相同的色彩和大小:宽高是32像素×32像素的黑白图像。尽管采用文本格式存储图像不能有效地利用内存空间,但是为了方便我们的理解,我们还是将图像转换为文本格式。示例如下:

然后,我们来看一下,使用kNN构造手写识别系统的步骤:

  1. 收集数据:提供文本文件。
  2. 准备数据:编写函数classify0(),将图像格式转换为分类器使用的list格式。
  3. 分析数据:在Python命令提示符中检查数据,确保它符合要求。
  4. 训练算法:此步骤不适用于k-近邻算法。
  5. 测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
  6. 使用算法:本例没有完成此步骤,若你感兴趣可以构建完整的应用程序,从图像中提取数字,并完成数字识别,美国的邮件分拣系统就是一个实际运行的类似系统。

2.3.1 准备数据:将图像转换为测试向量

老板给的训练集在目录trainingDigits中,其中包含了大约2000个例子,每个数字大概有200个样本。测试集在目录testDigits中,其中大约900个测试数据。截图如下:

每个文本文件名称下划线前的数字代表这个文本文件所代表数字。比如说0_8.txt代表的是数字0的第9个样本(从0开始计数)。

为了使用我们先前编写好的分类器,我们必须将图像格式化处理为一个向量。我们将一个32×32的二进制图像矩阵转换为1×1024的向量。

好了,代码走起来!我们继续在kNN.py中编写函数img2vector,代码如下:

def img2vector(filename):
    returnVect = zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0, 32 * i + j] = int(lineStr[j])
    return returnVect

代码很简单,就是将原来32×32转换成1×1024,这里我也就不多说什么了。大家可以自己去测试一下效果。

2.3.2 使用k-近邻算法识别手写数字

上一节我们已经把数据处理成我们想要的格式了,那么接下来我们就可以将这些数据丢到分类器里了。直接来看代码:

def handwritingClassTest():
    # 1.初始化我们所需要的数据
    hwLabels = []
    trainingFileList = os.listdir('trainingDigits')  # 这里需要我们提前导入os模块,listdir可以列出给定目录下的文件名
    m = len(trainingFileList)  # 获得训练样本数目
    trainingMat = zeros((m, 1024))  # 构造m×1024的矩阵
    
    # 2.循环遍历训练集中的每个文件,生成每个数字的向量信息,保存在trainingMat中
    for i in range(m):
        fileNameStr = trainingFileList[i]  # 获得文件名
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])  # 获得该文件所代表的数字
        hwLabels.append(classNumStr)  # 将文件所代表的数字其存放在类别标签中
        trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr)  # 数据转换
    
    # 3.遍历测试数据文件夹,使用kNN进行测试。
    testFileList = os.listdir('testDigits')
    errorCount = 0.0
    mTest = len(testFileList)  # 获得测试样本数目
    for i in range(mTest):
        fileNameStr = testFileList[i]  # 获得文件名
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])  # 获得该文件所代表的数字
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)  # 分类
        print('the classifier came back with: %d, the real answer is: %d' % (classifierResult, classNumStr))
        if classifierResult != classNumStr:
            errorCount += 1.0

    print('\nthe total number of errors is: %d' % errorCount)
    print('\nthe total error rate is: %f' % (errorCount / float(mTest)))

上面代码也不难,每一步的具体含义我都给大家写在注释中了,所以我也就不多说了。

依赖于机器速度,加载数据集可能要花费很长时间,然后函数开始依次测试每个文件,我们直接来看输出的结果:

我们使用k-近邻算法识别手写数字数据集,错误率为1.2%。

改变变量k的值、修改函数handwritingClassTest随机选取训练样本、改变训练样本的数目,都会对k-近邻算法的错误率产生影响,感兴趣的话可以改变这些变量值,观察错误率的变化。

但是,我们需要注意的是,实际使用这个算法时,算法的执行效率并不高。原因如下:

  1. 算法需要为每个测试向量做2000次距离计算,每个距离计算包括了1024个维度浮点运算,总计要执行900次,
  2. 此外,我们还需要为测试向量准备2MB的存储空间。

2.4 小结

kNN的理论、实战,我们就讲到这里了,下面我们来总结一下:

  1. k-近邻算法是分类数据最简单最有效的算法,我们通过两次实战讲述了如何使用k-近邻算法构造分类器。
  2. k-近邻算法是基于实例的学习,使用算法时我们必须有接近实际数据的训练样本数据。
  3. k-近邻算法必须保存全部数据集,如果训练数据集的很大,必须使用大量的存储空间。此外, 由于必须对数据集中的每个数据计算距离值,实际使用时可能非常耗时。
  4. k-近邻算法的另一个缺陷是它无法给出任何数据的基础结构信息,因此我们也无法知晓平均实例样本和典型实例样本具有什么特征。

好了,k-近邻算法我们就讲到这里,因为是最基础的,所以用了比较多的篇幅,希望大家能够慢慢看完,对机器学习先有一个感性的认识。

机器学习的路还很长,加油,冲冲冲!!!


最后,还是熟悉的配方!

欢迎大家关注我的公众号,有什么问题也可以给我留言哦!

猜你喜欢

转载自blog.csdn.net/RabitMountain/article/details/85466304