KNN手写数字识别

import numpy as np
import matplotlib .pyplot as plt 
from sklearn.neighbors import KNeighborsClassifier

 读取样本数据,图片

样本数据的提取

  • 特征:每一张图片对应的numpy数组
  • 目标:0,1,2,3,4,5,6,7,8,9
feature = []
target = []
for i in range(10):#i:0-9表示的是文件夹的名称
    for j in range(1,501):#j:1-500表示的是图片的名称的一部分
        imgPath = './data/'+str(i)+'/'+str(i)+'_'+str(j)+'.bmp'
        img_arr = plt.imread(imgPath)
        feature.append(img_arr)
        target.append(i)
feature = np.array(feature) #feature是一个三维的数组
target = np.array(target)
feature.shape
#(5000, 28, 28)

 feature目前是三维的numpy数组。必须变形成二维的才可以作为特征数据

feature = feature.reshape(5000,784)

 进行样本数据的打乱,并保证数据对应

np.random.seed(10)
np.random.shuffle(feature)
np.random.seed(10)
np.random.shuffle(target)
对样本数据进行拆分
测试数据
训练数据
knn = KNeighborsClassifier(n_neighbors=9)
knn.fit(x_train,y_train)
knn.score(x_test,y_test)
#对模型进行测试
print('真实的结果:',y_test)
print('模型分类的结果:',knn.predict(x_test))

 保存训练好的模型

from sklearn.externals import joblib
#保存
joblib.dump(knn,'./knn.m')
#读取
knn = joblib.load('./knn.m')
knn
将外部图片带入模型进行分类的测试
img_arr = plt.imread('./数字.jpg')
plt.imshow(img_arr)

 图片剪切

eight_arr = img_arr[175:240,85:135]
plt.imshow(eight_arr)
eight_arr.shape
#(65, 50, 3)
#模型只可以测试类似于测试数据中的特征数据
#将8对应的图片进行降维(65, 50, 3)降低成(784,)
eight_arr = eight_arr.mean(axis=2)
eight_arr.shape
#(65, 50)
#进行图片像素的等比例压缩
import scipy.ndimage as ndimage
eight_arr = ndimage.zoom(eight_arr,zoom=(28/65,28/50))
eight_arr = eight_arr.reshape(1,784)
eight_arr.shape
#(1, 784)
knn.predict(eight_arr)

 代码以及样本数据查看连接:https://github.com/dylan3714/-

猜你喜欢

转载自www.cnblogs.com/dylan123/p/12717056.html