kd-tree Handwritten Digit Recognition


For how to build and search a kd-tree, see the following blog post, which explains the principle in detail:

https://www.cnblogs.com/21207-iHome/p/6084670.html

The code is written for Python 3.x; the dataset is downloaded from the Digit Recognizer competition on Kaggle.

It is recommended to first run the code with data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]: with this small data set you can see the tree structure, which makes the search function much easier to follow. The data is taken from p. 42 of 《统计学习方法》.
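As an illustration, a minimal run on that toy data might look like the sketch below. It assumes the createTree and searchTree functions defined in the full script that follows; the query point [3, 4.5] is an arbitrary example, not from the book.

# Toy run (sketch): build a kd-tree from the 2-D data on p. 42 of 《统计学习方法》
# and query an arbitrary point; createTree/searchTree are defined in the script below.
data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
toy_tree = createTree(data)
print(toy_tree)                               # the nested dict shows the tree structure
point, dist = searchTree(toy_tree, [3, 4.5])  # [3, 4.5] is a made-up query point
print(point, dist)                            # nearest neighbour should be [2, 3]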

import time

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

'''Use a kd-tree to reduce the amount of computation in kNN; nearest-neighbour search is used.
Results with 12000 samples:
Data size:  (12000, 784)
create tree time:  329.07522320747375
accuracy:  0.9488888888888889
time:  1352.6775617599487'''

time_start = time.time()


def loadDataSet(filedir):
    DataSet = pd.read_csv(filedir)
    Data = DataSet.iloc[1:12001, 1:].values  # take 12000 samples and convert the DataFrame slice to a numpy array
    print('Data size: ', np.shape(Data))
    label = DataSet.iloc[1:12001, :1].values
    # split into training and test data
    train_data, test_data, train_label, test_label = train_test_split(Data, label, test_size=0.3, random_state=0)
    return train_data, test_data, train_label, test_label


def createTree(Data):
    treeNode = {}
    # compute the variance of every dimension and split on the dimension with the largest variance
    nums = len(Data)    # number of samples
    if nums == 0:
        return None
    dims = len(Data[0])  # dimensionality of each sample
    max_var = 0
    split_dim = 0   # splitting axis
    for dim in range(dims):
        temp = []
        for num in Data:
            temp.append(num[dim])
        variance = np.var(temp)
        if variance > max_var:
            max_var = variance
            split_dim = dim

    # partition the data around the median of the splitting axis
    sort_Data = sorted(Data, key=lambda d: d[split_dim])   # sort Data by the split_dim-th coordinate
    # print('split_dim: ', split_dim)
    # print('sort_Data: ', sort_Data)
    split_point = nums//2    # index of the median point
    # print('split_point: ', split_point)
    treeNode['split'] = split_dim
    # print('split_dim: ', split_dim)
    treeNode['median'] = sort_Data[split_point]     # the point stored at this node
    # build the left and right subtrees recursively
    treeNode['left'] = createTree(sort_Data[:split_point])
    treeNode['right'] = createTree(sort_Data[split_point+1:])
    return treeNode

# This function is the trickiest part. It is best to understand the kd-tree search procedure first and then read
# the code; running it with data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]] makes the recursion much
# easier to trace (a brute-force sanity check on this data is sketched after the script).
def searchTree(tree, data):
    k = len(data)
    if tree is None:
        # empty subtree: return a dummy point with infinite distance
        return [0]*k, float('inf')
    split_axis = tree['split']
    median_point = tree['median']
    # descend into the subtree on the same side of the splitting plane as the query point
    if data[split_axis] <= median_point[split_axis]:
        nearestPoint, nearestDistance = searchTree(tree['left'], data)
    else:
        nearestPoint, nearestDistance = searchTree(tree['right'], data)
    # check whether the point stored at this node is closer than the current best
    nowDistance = np.linalg.norm(np.array(data)-np.array(median_point))
    if nowDistance < nearestDistance:
        nearestDistance = nowDistance
        nearestPoint = median_point.copy()
    # distance from the query point to the splitting plane of this node
    splitDistance = abs(data[split_axis] - median_point[split_axis])
    if splitDistance > nearestDistance:
        # the ball around the query does not cross the splitting plane: skip the other side
        return nearestPoint, nearestDistance
    else:
        # the other side may still contain a closer point, so search it as well
        if data[split_axis] <= median_point[split_axis]:
            nextTree = tree['right']
        else:
            nextTree = tree['left']
        nearPoint, nearDistance = searchTree(nextTree, data)
        if nearDistance < nearestDistance:
            nearestDistance = nearDistance
            nearestPoint = nearPoint.copy()
        return nearestPoint, nearestDistance


def kNN_kdTree(train_data, test_data, train_label, test_label, tree):
    train = train_data.tolist()
    nums = test_data.shape[0]
    right = 0
    for num in range(nums):
        # find the nearest training point of this test sample in the kd-tree
        point, d = searchTree(tree, test_data[num])
        pt = point.tolist()
        # recover the label of the nearest neighbour via its index in the training set
        addr = train.index(pt)
        if train_label[addr] == test_label[num]:
            right += 1
    accuracy = right/nums
    return accuracy


train_data, test_data, train_label, test_label = loadDataSet(r'E:\ProgramData\Python3Project\Digit Recognizer/train.csv')
# print('train_data: ', train_data)
tree = createTree(train_data)
print('Create Tree Done!')
# print(tree)
time_createTree = time.time()
print('create tree time: ', time_createTree - time_start)
# train_data = train_data.tolist()
rate = kNN_kdTree(train_data, test_data, train_label, test_label, tree)
print('accuracy: ', rate)
time_end = time.time()
print('time: ', time_end - time_start)
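As a quick sanity check on searchTree (the part flagged above as hardest to follow), the kd-tree result can be compared against a brute-force nearest-neighbour scan on the toy data. This block is a sketch added for illustration, not part of the original script:

# Sketch: compare searchTree against a brute-force scan on the toy data.
toy_data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
toy_tree = createTree(toy_data)
query = [3, 4.5]   # arbitrary query point
bf_point = min(toy_data, key=lambda p: np.linalg.norm(np.array(p) - np.array(query)))
kd_point, kd_dist = searchTree(toy_tree, query)
print('brute force:', bf_point, ' kd-tree:', kd_point, ' distance:', kd_dist)
# both should report [2, 3] as the nearest neighbour of the query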
 
 
