手写识别系统

目的

采用k-近邻算法实现手写识别系统。这里采用0和1组成数字0-9的形状,再用算法对这些形状进行识别,来分辨出形状属于0-9那个数字。并计算出k-近邻算法识别手写数字的错误率。

数据说明

数据来自《机器学习实战》,分为测试集和训练集。单个数据如下图所示,表示数据0。 0的数字图

算法过程

  1. 收集数据:提供文本文件。
  2. 准备数据:编写函数classify0() ,将图像格式转换为分类器使用的制格式。
  3. 分析数据:在Python命令提示符中检查数据,确保它符合要求。
  4. 训练算法:此步驟不适用于k-近邻算法。
  5. 测试算法:编写函数使用提供的部分数据集作为测试样本,测试样本与非测试样本的区别在于测试样本是已经完成分类的数据,如果预测分类与实际类别不同,则标记为一个错误。
  6. 使用算法

准备数据:将图像转换为测试向量

def img2vector(filename):
    """
    该函数创建1*1024的NumPy数组,然后打开给定的文件,循环读出文件的前32行。
     并将每行的头32个字符值存储在NumPy数组中,然后返回数组
    """
    return_vect = np.zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        line_str = fr.readline()
        for j in range(32):
            return_vect[0, 32*i+j] = int(line_str[j])
    return return_vect

测试算法:使用k-近邻算法识别手写数字

def hand_writing_class_test():
    """
    手写数字识别系统的测试代码
    """
    hw_labels = []
    traing_file_list = os.listdir('../Data/Ch02/trainingDigits')  # 获取目录内容
    m = len(traing_file_list)
    matrix_of_training = np.zeros((m, 1024))
    for i in range(m):
        # 从文件名解析分类数字
        file_name_str = traing_file_list[i]
        file_str = file_name_str.split('.')[0]
        class_num_str = int(file_str.split('_')[0])

        hw_labels.append(class_num_str)
        matrix_of_training[i, :] = img2vector('../Data/Ch02/trainingDigits/%s' % file_name_str)
    test_file_list = os.listdir('../Data/Ch02/testDigits')
    error_count = 0.0
    m_test = len(test_file_list)
    for i in range(m_test):
        file_name_str = test_file_list[i]
        file_str = file_name_str.split('.')[0]
        class_num_str = int(file_str.split('_')[0])

        vector_under_test = img2vector('../Data/Ch02/testDigits/%s' % file_name_str)
        classifier_result = classify0(vector_under_test, matrix_of_training,
                                      hw_labels, 3)
        print("the clasifier came back with: %d, the real answer is %d" % (classifier_result, class_num_str))
        if classifier_result != class_num_str:
            error_count += 1.0
        print("\nthe total number of errors is %d" % error_count)
        print("\nthe total error rate is: %f" % (error_count/float(m_test)))

完整代码

# -*- coding: utf-8 -*-
# @Function :  使用k-近邻算法识别手写数字
import numpy as np
import os
import operator


def classify0(in_x, data_set, labels, k):
    """
    k-近邻算法
    :param in_x: 用于分类的输入向量X
    :param data_set: 输入的训练样本集data_set
    :param labels: 标签向量,其元素数目与矩阵data_set的行数相同
    :param k: 选择最近邻居的数目
    :return: 发生频率最高的元素标签
    """
    dataset_size = data_set.shape[0]

    # 原型:numpy.tile(A,reps)
    # tile共有2个参数,A指待输入数组,reps则决定A重复的次数。整个函数用于重复数组A来构建新的数组。
    # 计算距离,欧式距离公式:sqrt(pow(xA0-xB0, 2) + pow(xA1-xB1, 2))
    diff_mat = np.tile(in_x, (dataset_size, 1)) - data_set
    sq_diff_mat = diff_mat ** 2
    sq_distances = sq_diff_mat.sum(axis=1)

    distances = sq_distances ** 0.5
    # numpy.argsort() 返回排好序的序列的索引
    sorted_dist_indicies = distances.argsort()

    class_count = {}
    # 选择距离最小的k个节点
    for i in range(k):
        vote_I_label = labels[sorted_dist_indicies[i]]
        class_count[vote_I_label] = class_count.get(vote_I_label, 0) + 1

    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)
    return sorted_class_count[0][0]

def img2vector(filename):
    """
    该函数创建1*1024的NumPy数组,然后打开给定的文件,循环读出文件的前32行。
     并将每行的头32个字符值存储在NumPy数组中,然后返回数组
    """
    return_vect = np.zeros((1, 1024))
    fr = open(filename)
    for i in range(32):
        line_str = fr.readline()
        for j in range(32):
            return_vect[0, 32*i+j] = int(line_str[j])
    return return_vect


def hand_writing_class_test():
    """
    手写数字识别系统的测试代码
    """
    hw_labels = []
    traing_file_list = os.listdir('../Data/Ch02/trainingDigits')  # 获取目录内容
    m = len(traing_file_list)
    matrix_of_training = np.zeros((m, 1024))
    for i in range(m):
        # 从文件名解析分类数字
        file_name_str = traing_file_list[i]
        file_str = file_name_str.split('.')[0]
        class_num_str = int(file_str.split('_')[0])

        hw_labels.append(class_num_str)
        matrix_of_training[i, :] = img2vector('../Data/Ch02/trainingDigits/%s' % file_name_str)
    test_file_list = os.listdir('../Data/Ch02/testDigits')
    error_count = 0.0
    m_test = len(test_file_list)
    for i in range(m_test):
        file_name_str = test_file_list[i]
        file_str = file_name_str.split('.')[0]
        class_num_str = int(file_str.split('_')[0])

        vector_under_test = img2vector('../Data/Ch02/testDigits/%s' % file_name_str)
        classifier_result = classify0(vector_under_test, matrix_of_training,
                                      hw_labels, 3)
        print("the clasifier came back with: %d, the real answer is %d" % (classifier_result, class_num_str))
        if classifier_result != class_num_str:
            error_count += 1.0
        print("\nthe total number of errors is %d" % error_count)
        print("\nthe total error rate is: %f" % (error_count/float(m_test)))


if __name__ == '__main__':
    hand_writing_class_test()

结果

输出结果

猜你喜欢

转载自my.oschina.net/chenmoxuan/blog/1820719