# 版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/qq_27668313/article/details/78994064
# kd-tree construction and search; the underlying theory is explained in detail at:
# https://www.cnblogs.com/21207-iHome/p/6084670.html
# Code targets Python 3.x; the dataset is the Kaggle "Digit Recognizer" train.csv.
# Tip: first run with data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]] (the example
# from 《统计学习方法》 p.42) so the tree structure is small enough to inspect, which makes
# the search function much easier to follow.
from numpy import *
import numpy as np
import operator
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import time
'''使用kdTree以减少kNN运算量,采用最近邻搜索
使用12000样本数结果如下:
Data size: (12000, 784)
create tree time: 329.07522320747375
accuracy: 0.9488888888888889
time: 1352.6775617599487'''
# (The string above records a past run: 1-NN search on a kd-tree over 12000 samples,
#  build time ~329 s, accuracy ~0.949, total time ~1353 s.)
# Wall-clock reference used by the timing printouts at the bottom of the script.
time_start = time.time()
def loadDataSet(filedir):
    """Load the Kaggle digit-recognizer CSV and split it into train/test sets.

    Parameters
    ----------
    filedir : str
        Path to train.csv; first column is the digit label, the remaining
        784 columns are the flattened 28x28 pixel values.

    Returns
    -------
    train_data, test_data, train_label, test_label : numpy.ndarray
        A deterministic 70/30 split (random_state=0). Labels are (n, 1) arrays.
    """
    dataset = pd.read_csv(filedir)
    # Keep only rows 1..12000 so the pure-Python kd-tree stays tractable.
    # .to_numpy() replaces DataFrame.as_matrix(), which was removed in pandas 1.0.
    data = dataset.iloc[1:12001, 1:].to_numpy()
    print('Data size: ', np.shape(data))
    label = dataset.iloc[1:12001, :1].to_numpy()
    # 分配训练数据和测试数据 -- 70% train / 30% test, fixed seed for reproducibility.
    train_data, test_data, train_label, test_label = train_test_split(data, label, test_size=0.3, random_state=0)
    return train_data, test_data, train_label, test_label
def createTree(Data):
    """Recursively build a kd-tree over the samples in ``Data``.

    At each node the splitting axis is the dimension with the largest
    variance, and the node stores the sample whose value is the median
    along that axis (index ``len(Data) // 2`` after sorting).

    Parameters
    ----------
    Data : sequence of samples (rows of equal length)

    Returns
    -------
    dict | None
        ``{'split': axis, 'median': sample, 'left': subtree, 'right': subtree}``,
        or None for an empty sample set.
    """
    nums = len(Data)  # number of samples in this subtree
    if nums == 0:
        return None
    # Vectorized per-dimension variance; argmax picks the first maximal axis,
    # matching the original strict '>' scan (ties resolve to the lowest index).
    variances = np.var(np.asarray(Data, dtype=float), axis=0)
    split_dim = int(np.argmax(variances))
    # Sort along the chosen axis so the median sample lands in the middle.
    sort_Data = sorted(Data, key=lambda d: d[split_dim])
    split_point = nums // 2  # median index; that sample is stored in this node
    treeNode = {
        'split': split_dim,
        'median': sort_Data[split_point],
        # Samples strictly before the median go left, strictly after go right.
        'left': createTree(sort_Data[:split_point]),
        'right': createTree(sort_Data[split_point + 1:]),
    }
    return treeNode
# 这个函数比较难理解,建议看懂搜索树原理再看代码,为了更好的看懂程序,建议使用如下数据运行代码,可视化# 程度更高 data = [[2,3], [5,4], [9,6], [4,7], [8,1], [7,2]]
def searchTree(tree, data):
    """Return ``(nearest_point, nearest_distance)`` for ``data`` in the kd-tree.

    Classic single-nearest-neighbour descent: walk down the side of the
    splitting plane that contains the query, then unwind, checking the node's
    own sample and — only when the splitting plane is closer than the best
    distance so far — the opposite subtree.
    """
    # Empty subtree: report a dummy candidate at infinite distance so any
    # real sample encountered on the way back up replaces it.
    if tree is None:
        return [0] * len(data), float('inf')

    axis = tree['split']
    node_point = tree['median']
    on_left = data[axis] <= node_point[axis]

    # Descend first into the half-space that contains the query.
    near_side = tree['left'] if on_left else tree['right']
    best_point, best_dist = searchTree(near_side, data)

    # The sample stored at this node may beat the best found below it.
    d_node = np.linalg.norm(np.array(data) - np.array(node_point))
    if d_node < best_dist:
        best_dist = d_node
        best_point = node_point.copy()

    # Distance from the query to the splitting hyperplane.
    plane_dist = abs(data[axis] - node_point[axis])
    if plane_dist > best_dist:
        # The candidate hypersphere does not cross the plane, so the far
        # side cannot contain anything closer — prune it.
        return best_point, best_dist

    # Otherwise the far subtree might still hold a closer sample.
    far_side = tree['right'] if on_left else tree['left']
    cand_point, cand_dist = searchTree(far_side, data)
    if cand_dist < best_dist:
        best_dist = cand_dist
        best_point = cand_point.copy()
    return best_point, best_dist
def kNN_kdTree(train_data, test_data, train_label, test_label, tree):
    """Score 1-nearest-neighbour classification accuracy using the kd-tree.

    For each test sample, find its nearest training sample via ``searchTree``,
    locate that sample in ``train_data`` to recover its label, and count how
    often the predicted label matches the ground truth.
    """
    train_rows = train_data.tolist()
    total = len(test_data)
    correct = 0
    for idx in range(total):
        nearest, _dist = searchTree(tree, test_data[idx])
        # Map the returned point back to its row index to fetch its label.
        row = train_rows.index(nearest.tolist())
        if train_label[row] == test_label[idx]:
            correct += 1
    return correct / total
# ---- Script driver ---------------------------------------------------------
# Raw string for the Windows path: '\P' and '\D' are invalid escape sequences
# (SyntaxWarning on modern Python); r'...' keeps the exact same runtime value.
train_data, test_data, train_label, test_label = loadDataSet(r'E:\ProgramData\Python3Project\Digit Recognizer/train.csv')
# print('train_data: ', train_data)
tree = createTree(train_data)
print('Create Tree Done!')
# print(tree)
time_createTree = time.time()
print('create tree time: ', time_createTree - time_start)
# train_data = train_data.tolist()
rate = kNN_kdTree(train_data, test_data, train_label, test_label, tree)
print('accuracy: ', rate)
time_end = time.time()
print('time: ', time_end - time_start)