python基于numpy实现knn分类器并运用于手写数字识别

在写代码的时候,我一开始将训练、测试集的标签转化为了one_hot表示方式,忽视了这个在本题中时不需要one_hot表示方式的,所以花了挺多时间调试这个,本段代码开头调用了keras模块的datasets,仅仅用于加载训练集和测试集,其他的部分都是基于numpy模块和前一篇博文实现的knn分类器,有关knn分类器的实现大家参考我的这一片博文:
https://blog.csdn.net/weixin_43141320/article/details/105315522
先放代码:

import numpy as np
import KNN
from keras.datasets import mnist


# def one_hot_(x, depth):
#     x = np.array(x).reshape((1, -1))
#     labels = np.zeros((x.shape[1], depth))
#     labels[np.arange(x.shape[1]), x[0]] = 1
#     return labels


def loadDataset():
    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
    train_data = train_data[: 10000]
    train_labels = train_labels[: 10000]
    test_data = test_data[45: 100]
    test_labels = test_labels[45: 100]
    # print("all the shape:", train_data.shape, train_labels.shape, test_labels.shape, test_data.shape)
    train_data = train_data.reshape((-1, 28, 28, 1))
    train_data = train_data.astype('float') / 255

    test_data = test_data.reshape((-1, 28, 28, 1))
    test_data = test_data.astype('float') / 255
    # train_num = train_data.shape[0]
    # test_num = test_data.shape[0]

    # 标签值不要转为one_hot表示方式,不然后面会出错。
    # train_labels = to_categorical(train_labels)
    # test_labels = to_categorical(test_labels)
    # print("all the shape:", train_data.shape, train_labels.shape, test_labels.shape, test_data.shape)
    # train_labels = one_hot_(train_labels, 10)
    # test_labels = one_hot_(test_labels, 10)

    return train_data, train_labels, test_data, test_labels


def handWritingTest(train_x, train_y, test_x, test_y):
    cls = KNN.KNNClassifier(k=3)
    error = 0
    result = cls.classify(test_x, train_x, train_y)

    for i in range(len(result)):
        if result[i] != test_y[i]:
            error += 1
    precision_rate = 1 - error / len(test_y)
    print("precision: ", precision_rate)


def main():
    train_data, train_labels, test_data, test_labels = loadDataset()
    handWritingTest(train_data, train_labels, test_data, test_labels)


if __name__ == "__main__":
    main()

代码很简单,因为主要的分类器不在这段代码实现的,现在说说各个函数的作用:
loadDataset()函数调用keras 的datasets来加载数据集,由于mnist的数据集很多,为了减少训练时间,这里的训练集选取前10000张图片,测试集选的是第46到第100张图片。
handWritingTest()函数接收传进来的数据集,调用knn分类器,并计算精确值,最终返回预测的精度。
在创建KNN.KNNClassifier(k=3)类的时候,除了有参数k,还有一个参数processBar,默认值为True,作用是显示训练的进度。
先看看训练的效果:
在这里插入图片描述

在目前的训练集和测试集的前提下,预测精度达到0.95左右,当然我们可以通过增加训练集,来获取更优越的成就感~

发布了79 篇原创文章 · 获赞 8 · 访问量 3319

猜你喜欢

转载自blog.csdn.net/weixin_43141320/article/details/105328590
今日推荐