Machine Learning in Action, KNN: Handwritten Digit Recognition

    I have been studying machine learning recently and find it fascinating, so here is a summary of the KNN algorithm.

I. Introduction

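        KNN (k-nearest neighbors) is a supervised classification algorithm. The idea is simple: measure the distance between the sample to be predicted and every training sample, select the K training samples with the smallest distances (a variant instead takes every training sample within a fixed distance), and predict the class whose label occurs most often among the selected samples.

        The steps of KNN:

  1. Compute the distance between the test sample and every training sample
  2. Sort the distances in ascending order
  3. Select the k points closest to the current point
  4. Count how often each class occurs among those k points
  5. Return the most frequent class among the k points as the predicted class of the current point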
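The five steps above can be sketched as one small function. This is a minimal illustration, not the code from the book: the name `knn_classify` and the four-point toy dataset are made up for this example, and Euclidean distance is assumed.

```python
from collections import Counter

import numpy as np


def knn_classify(x, train_data, train_labels, k=3):
    """Predict the label of one sample x by majority vote of its k nearest neighbours."""
    # Step 1: Euclidean distance from x to every training sample
    distances = np.sqrt(((train_data - x) ** 2).sum(axis=1))
    # Steps 2-3: sort by distance and keep the indices of the k closest samples
    nearest = np.argsort(distances)[:k]
    # Step 4: count how often each label occurs among those k samples
    votes = Counter(train_labels[i] for i in nearest)
    # Step 5: return the most frequent label
    return votes.most_common(1)[0][0]


# Hypothetical toy data: two points near (1, 1) labelled 'A', two near (0, 0) labelled 'B'
train_data = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
train_labels = ['A', 'A', 'B', 'B']
print(knn_classify(np.array([0.1, 0.2]), train_data, train_labels, k=3))  # prints B
```

With k=3 the three closest points carry the labels B, B and A, so the vote returns B.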

II. Advantages and Disadvantages

        Advantages: high accuracy, insensitive to outliers, no assumptions about the input data

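        Disadvantages: high computational complexity, high space complexity, slow prediction (every test sample must be compared with the whole training set)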
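To make the cost concrete: classifying m test samples against n training samples with d features requires m × n distances, each touching d values, and the whole training set has to stay in memory. The sketch below computes that full distance matrix in one shot. The function name `pairwise_distances` and the random binary data are assumptions made for this illustration only; they are not part of the original code.

```python
import numpy as np


def pairwise_distances(test_data, train_data):
    """Euclidean distances between every test row and every training row."""
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated for all pairs at once
    test_sq = (test_data ** 2).sum(axis=1)[:, np.newaxis]    # shape (m, 1)
    train_sq = (train_data ** 2).sum(axis=1)[np.newaxis, :]  # shape (1, n)
    squared = test_sq + train_sq - 2.0 * test_data @ train_data.T
    # Rounding can leave tiny negative values; clip them before the square root
    return np.sqrt(np.maximum(squared, 0.0))


rng = np.random.default_rng(0)
train = rng.integers(0, 2, size=(200, 1024)).astype(float)  # made-up binary vectors
test = rng.integers(0, 2, size=(5, 1024)).astype(float)
dist = pairwise_distances(test, train)
print(dist.shape)  # (5, 200): one distance per (test, training) pair
```

The result has m × n entries, so both the running time and the memory footprint grow with the size of the training set.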

III. Hands-on Practice

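        After reading the kNN chapter of Machine Learning in Action, I reimplemented the book's handwritten digit recognition (classification) example myself. The dataset and code are uploaded to https://github.com/ZhangPengFe/KNN_HandWrite . I hope you will write it yourself as well, because what you gain from writing the code is completely different from what you gain from reading it.

        Since the code is simple, I will not explain it line by line. If any part is unclear, leave a comment and I will discuss it with you when I see it.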

import numpy as np
import os


def load_data(file_dir, file_name):
    file_path = os.path.join(file_dir, file_name)
    data = np.zeros((1, 1024))
    with open(file_path, 'r') as f:
        for i in range(32):
            line = f.readline()
            for j in range(32):
                data[0, i*32+j] = int(line[j])
    label = int(file_name.split('_')[0])
    return data, label


def handwrite():
    train_dir = 'trainingDigits'
    train_filename = os.listdir(train_dir)  # list of all file names under train_dir
    train_filenum = len(train_filename)
    train_data = np.zeros((train_filenum, 1024))
    train_label = []
    # read train data and label
    for i in range(train_filenum):
        train_data[i, :], label = load_data(train_dir, train_filename[i])
        train_label.append(label)

    test_dir = 'testDigits'
    test_filename = os.listdir(test_dir)
    test_filenum = len(test_filename)
    real = 0  # number of correct predictions
    for i in range(test_filenum):
        a_test_data, test_label = load_data(test_dir, test_filename[i])
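        # Euclidean distance from this test sample to every training sample
        diff_data = np.tile(a_test_data, (train_filenum, 1)) - train_data
        diff_data_square = diff_data**2
        temp = diff_data_square.sum(axis=1)
        distance = temp**0.5
        # indices of the training samples, ordered from nearest to farthest
        sort = distance.argsort()
        # vote counter for each digit 0-9
        predict = {}
        for j in range(10):
            predict[j] = 0
        for j in range(3):          # k of knn is 3
            predict[train_label[sort[j]]] += 1
        # the digit with the most votes is the prediction
        pre_result = max(predict, key=predict.get)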
        print("the classifier came back with: %d, the real answer is: %d" % (pre_result, test_label))
        if pre_result == test_label:
            real += 1
    print("accuracy: %f" % (real / test_filenum))


if __name__ == '__main__':
    handwrite()
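Judging from `load_data` above, each sample file holds 32 lines of 32 characters, each '0' or '1', and the digit before the underscore in the file name is the label. The sketch below shows the same conversion on an in-memory string so it can be run without the dataset. The helper names `text_to_vector` and `label_from_name`, the vertical-bar image, and the file name `3_12.txt` are hypothetical and used only for illustration.

```python
import numpy as np


def text_to_vector(text):
    """Flatten 32 lines of 32 '0'/'1' characters into a (1, 1024) array."""
    rows = text.splitlines()[:32]
    digits = [[int(ch) for ch in row[:32]] for row in rows]
    return np.array(digits, dtype=float).reshape(1, 1024)


def label_from_name(file_name):
    """The digit before the first underscore is the class label."""
    return int(file_name.split('_')[0])


# Made-up image: a vertical bar occupying columns 15 and 16 of every row
sample = "\n".join("0" * 15 + "11" + "0" * 15 for _ in range(32))
vec = text_to_vector(sample)
print(vec.shape, int(vec.sum()), label_from_name("3_12.txt"))  # (1, 1024) 64 3
```

Row i, column j of the image ends up at index i*32+j of the vector, which matches the indexing used in `load_data`.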

Reposted from blog.csdn.net/zpf123456789zpf/article/details/87971039