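Algorithm 3.2 and Algorithm 3.3 from Chapter 3 of 《统计学习方法》 (Statistical Learning Methods)
Implements KD-tree construction and nearest-neighbor search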
Algorithm 3.2:
Construct the KD tree
Input: spatial data set T
Output: KD tree
Python code:
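def creat_ketree(data, depth=0):
    """
    Build a KD tree recursively.

    axis is the splitting coordinate axis
    num is the index of the median node
    depth is the depth of the tree
    median is the point stored at the parent node
    left, right are the left and right child nodes

    :param data: list of points, all of the same dimension
    :param depth: depth of the node being built (0 for the root)
    :return: the tree as nested dicts, or None for an empty data set
    """
    try:
        m = len(data[0])  # dimension of the points
    except IndexError:
        return None  # empty data set: no node to build
    tree_node = {}
    axis = depth % m  # cycle through the axes as the depth grows
    depth += 1
    tree_node['split'] = axis
    data = sorted(data, key=lambda point: point[axis])
    num = len(data) // 2  # index of the median point along axis
    tree_node['median'] = data[num]
    tree_node['left'] = creat_ketree(data[:num], depth)
    tree_node['right'] = creat_ketree(data[num + 1:], depth)
    return tree_node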
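
To make the nested-dict layout (`split`, `median`, `left`, `right`) concrete, here is a minimal self-contained sketch that builds a tree and lists its nodes in pre-order. The names `build_kdtree` and `preorder` are hypothetical helpers, not part of the original code. The median index is assumed to be `len(points) // 2`, and the six-point data set is the one usually quoted for Example 3.2 of the book, used here purely as an illustration.

```python
def build_kdtree(points, depth=0):
    """Build a KD tree as nested dicts (hypothetical helper for illustration)."""
    if not points:
        return None
    axis = depth % len(points[0])           # splitting axis cycles with depth
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2                  # assumed median index
    return {
        'split': axis,
        'median': points[mid],
        'left': build_kdtree(points[:mid], depth + 1),
        'right': build_kdtree(points[mid + 1:], depth + 1),
    }


def preorder(node):
    """Return the stored points in pre-order: node, left subtree, right subtree."""
    if node is None:
        return []
    return [node['median']] + preorder(node['left']) + preorder(node['right'])


T = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
tree = build_kdtree(T)
print(preorder(tree))
# → [(7, 2), (5, 4), (2, 3), (4, 7), (9, 6), (8, 1)]
```

The root splits on the x axis at (7, 2). Its children split on the y axis at (5, 4) and (9, 6).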
Algorithm 3.3:
Search the KD tree
Input: an already constructed KD tree, target point x
Output: the nearest neighbor of x
Python code:
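import math


def Euclidean_distance(A, B):
    """
    Compute the Euclidean distance between two points.
    :param A: first point (sequence of numbers)
    :param B: second point, same length as A
    :return: the Euclidean distance between A and B
    """
    sum_distance = 0
    for i in range(len(A)):
        sum_distance += pow(abs(A[i] - B[i]), 2)
    return math.sqrt(sum_distance)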
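def search_tree(tree, data):
    """
    First descend to the leaf node nearest to the query point data,
    then backtrack up the kd tree, visiting the other subtree where needed.
    :param tree: KD tree (nested dicts with 'split', 'median', 'left', 'right'), or None
    :param data: query point
    :return: (nearest point, distance to the nearest point)
    """
    k = len(data)
    if tree is None:
        return [0] * k, float('inf')
    else:
        median_point = tree['median']
        node_axis = tree['split']
        # Descend towards the leaf node nearest to the query point data
        if data[node_axis] > median_point[node_axis]:
            nearest_point, nearest_distance = search_tree(tree['right'], data)
        else:
            nearest_point, nearest_distance = search_tree(tree['left'], data)
        # Distance between the current node and the query point data
        now_distance = Euclidean_distance(data, median_point)
        # If the current node is closer, update the nearest point and distance
        if nearest_distance > now_distance:
            nearest_point = median_point.copy()
            nearest_distance = now_distance
        # Compare the best distance so far with the distance to the splitting
        # plane along axis; if the plane is farther away, the other subtree
        # cannot contain a closer point
        if nearest_distance < abs(median_point[node_axis] - data[node_axis]):
            return nearest_point, nearest_distance
        else:
            # Otherwise search the subtree that was skipped during the descent
            if data[node_axis] > median_point[node_axis]:
                nearer_point, nearer_distance = search_tree(tree['left'], data)
            else:
                nearer_point, nearer_distance = search_tree(tree['right'], data)
            if nearer_distance < nearest_distance:
                nearest_distance = nearer_distance
                nearest_point = nearer_point.copy()
            return nearest_point, nearest_distance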
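
One way to gain confidence in a KD-tree search is to compare it with a brute-force scan over all points. The sketch below is a compact, self-contained variant of the same idea, not the listing above. `build_kdtree`, `distance`, and `nearest` are hypothetical names, the median index is assumed to be `len(points) // 2`, and the random 2-D points are made up for the check.

```python
import math
import random


def distance(a, b):
    """Euclidean distance between two points of equal dimension."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def build_kdtree(points, depth=0):
    """Build a KD tree as nested dicts (hypothetical helper for illustration)."""
    if not points:
        return None
    axis = depth % len(points[0])
    points = sorted(points, key=lambda p: p[axis])
    mid = len(points) // 2  # assumed median index
    return {
        'split': axis,
        'median': points[mid],
        'left': build_kdtree(points[:mid], depth + 1),
        'right': build_kdtree(points[mid + 1:], depth + 1),
    }


def nearest(tree, query, best=(None, float('inf'))):
    """Return (point, distance) for the nearest neighbor of query."""
    if tree is None:
        return best
    point, axis = tree['median'], tree['split']
    d = distance(point, query)
    if d < best[1]:
        best = (point, d)
    diff = query[axis] - point[axis]
    # Visit the side containing the query first, then the other side only
    # if the splitting plane is within the best distance found so far.
    if diff > 0:
        near, far = tree['right'], tree['left']
    else:
        near, far = tree['left'], tree['right']
    best = nearest(near, query, best)
    if abs(diff) <= best[1]:
        best = nearest(far, query, best)
    return best


random.seed(0)
points = [(random.random(), random.random()) for _ in range(200)]
tree = build_kdtree(points)
for _ in range(100):
    q = (random.random(), random.random())
    _, d_tree = nearest(tree, q)
    d_brute = min(distance(p, q) for p in points)
    assert math.isclose(d_tree, d_brute)
print("KD-tree search agrees with brute force")
```

The pruning test uses the best distance found so far, not the distance to the current node. A node's own distance to the query is never smaller than its distance to the query along one axis, so a test on the node's own distance would never prune anything.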