基于K-近邻算法的手写字符识别

收集数据集：包括 trainingDigits（训练集）和 testDigits（测试集）两个文件夹，文件夹下是手写数字图像的二进制文本文件。每个文件有 32 行、每行 32 个字符（'0' 或 '1'），即一幅 32×32 的二值图像；文件名形如 `0_13.txt`，下划线前的数字就是该样本的真实类别标签。示例如下图所示：

1. 将二进制图像文本文件转换为行向量

首先介绍 `os.listdir`：它返回指定目录下所有条目（包括文件和子目录）的名称列表，不会递归进入子目录，且返回顺序不保证：

import os
from os import listdir
# os.listdir gives the names of every entry (files and sub-directories)
# directly inside the folder, without recursing into sub-directories.
os.listdir(r'C:\Users\Administrator\Desktop\python学习\6.pandas学习')

运行结果:
['.ipynb_checkpoints',
 'admissions.csv',
 'HappyFish.jpg',
 'pandas处理丢失数据.ipynb',
 'pandas多级索引.ipynb']

下面的函数把一幅 32×32 的二进制图像文本读取为 1×1024 的行向量：

#将32*32的二进制图像转换为1*1024的向量
import numpy  as np
def img2vector(filename):
    vector = np.zeros((1,1024))
    f = open(filename)
    for i in range(32):
        line = f.readline()
        for j in range(32):
            vector[0,i*32+j] = int(line[j])
    return vector

2. 构建 KNN 分类器模型

对于一个待分类样本，先计算它与所有训练样本之间的欧氏距离，再取距离最近的 K 个训练样本，按多数表决把其中出现次数最多的类别作为预测结果。

def classifier(inX, dataSet, labels, K):
    """Classify one sample by k-nearest-neighbours majority voting.

    Args:
        inX: The sample to classify: a 1-D array of n features or a (1, n)
            row vector.
        dataSet: Training samples as an (m, n) array-like, one per row.
        labels: Sequence of length m; labels[i] is the class of row i of
            dataSet.
        K: Number of nearest neighbours that vote; 1 <= K <= m.

    Returns:
        The label occurring most often among the K training samples closest
        to inX in Euclidean distance.  A tied vote goes to the label whose
        first voter is nearest; equal distances are broken in favour of the
        training sample with the lower row index.

    Raises:
        ValueError: if K is not in the range 1..m.
    """
    from collections import Counter

    dataSet = np.asarray(dataSet)
    row = len(dataSet)
    if not 1 <= K <= row:
        raise ValueError("K must be between 1 and %d, got %r" % (row, K))

    # Broadcasting subtracts inX from every row; tiling it first is unneeded.
    diff = dataSet - np.asarray(inX)
    distance = np.sqrt(np.sum(diff ** 2, axis=1))
    # A stable sort keeps the neighbour order (and thus the vote)
    # deterministic when several samples are equally far away, which is
    # frequent with integer distances between binary images.
    nearest = np.argsort(distance, kind="stable")[:K]
    votes = Counter(labels[idx] for idx in nearest)
    # most_common keeps first-encountered order for equal counts, so the
    # label of the nearer neighbour wins a tied vote.
    return votes.most_common(1)[0][0]

3. 评估模型

用 trainingDigits 中的样本作为训练集，对 testDigits 中的每个样本进行预测（K=3），统计预测错误的样本所占比例，即错误率。

#测试错误率,评估模型
def _label_from_filename(name):
    """Return the integer label encoded before the first '_' of a file name."""
    return int(name.split('.')[0].split('_')[0])


def _load_digit_dir(directory):
    """Load every image file in directory into an (m, 1024) matrix plus labels."""
    # listdir order is arbitrary; sorting makes runs reproducible.
    names = sorted(listdir(directory))
    matrix = np.zeros((len(names), 1024))
    labels = []
    for row, name in enumerate(names):
        matrix[row] = img2vector(os.path.join(directory, name))
        labels.append(_label_from_filename(name))
    return matrix, labels


def handwriting_character_recognition(train_path='trainingDigits',
                                      test_path='testDigits', k=3):
    """Build a k-NN digit classifier and report its error rate on a test set.

    Every file in train_path and test_path must be a 32x32 text image
    readable by img2vector, and its name must start with "<label>_"
    (e.g. "0_13.txt").  The integer before the first underscore is the true
    label.  For each test file the real and predicted labels are printed,
    followed by the overall error rate.

    Args:
        train_path: Directory holding the training images.
        test_path: Directory holding the test images.
        k: Number of neighbours passed to classifier.

    Returns:
        The error rate, i.e. the fraction of misclassified test files.

    Raises:
        ValueError: if either directory contains no files.
    """
    dataSet, labels = _load_digit_dir(train_path)
    if not labels:
        raise ValueError("no training samples found in %r" % (train_path,))

    test_files = sorted(listdir(test_path))
    if not test_files:
        # Avoids a ZeroDivisionError when computing the error rate.
        raise ValueError("no test samples found in %r" % (test_path,))

    errors = 0
    for name in test_files:
        real_label = _label_from_filename(name)
        sample = img2vector(os.path.join(test_path, name))
        pred_label = classifier(sample, dataSet, labels, k)
        print("real_label is %d,pred_label is %d" % (real_label, pred_label))
        if pred_label != real_label:
            errors += 1
    error_rate = errors / len(test_files)
    print("The error rate is %.2f" % error_rate)
    return error_rate

if __name__ == "__main__":
    # Run the evaluation only when executed directly (script or notebook
    # cell), not as a side effect of importing this code as a module.
    handwriting_character_recognition()
运行结果如下:
...
real_label is 9,pred_label is 9
real_label is 9,pred_label is 9
real_label is 9,pred_label is 9
real_label is 9,pred_label is 9
real_label is 9,pred_label is 9
real_label is 9,pred_label is 9
The error rate is 0.01

猜你喜欢

转载自blog.csdn.net/qq_24946843/article/details/83861423