## KD 树的 Python 实现

class Kd_node:
    """A single node of a kd-tree.

    Attributes:
        value: the sample (one row of the training data) stored at this node.
        deep: depth of the node in the tree (the root is built with depth 0).
        feature: index of the feature (column) this node splits on.
        left: left child ``Kd_node``, or ``None``.
        right: right child ``Kd_node``, or ``None``.
        parent: parent ``Kd_node``; ``None`` for the root.
    """

    def __init__(self, value=None, deep=None, feature=None,
                 left=None, right=None, parent=None):
        # Attributes are assigned per instance so that nodes never share one
        # mutable default list for ``value``.
        self.value = [] if value is None else value
        self.deep = deep
        self.feature = feature
        self.left = left
        self.right = right
        self.parent = parent

    def __repr__(self):
        # left/right/parent are deliberately left out: following parent links
        # would make the repr recurse forever.
        return "{}(value={!r}, deep={!r}, feature={!r})".format(
            type(self).__name__, self.value, self.deep, self.feature)

1. 建立 KD 树

def Train(x):
    """Build a kd-tree from the training samples *x* and return its root.

    The root splits on the feature (column) with the largest variance;
    ``BuildKdTree`` then cycles through the features level by level.

    :param x: training samples, array-like of shape (n_samples, n_features)
    :return: root ``Kd_node`` of the tree, or ``None`` if ``x`` is empty
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.var/np.argmax are meaningless on empty data; there is no tree.
        return None
    # 1. Choose the first splitting feature: the column with maximum variance.
    kd_var = np.var(x, axis=0)
    max_feature_index = int(np.argmax(kd_var))
    # 2. Recursively split the data. The first argument of BuildKdTree is
    #    unused and is kept only for backward compatibility.
    return BuildKdTree(None, x, max_feature_index, 0)


def BuildKdTree(kdnode, x, index, deep, verbose=False):
    """Recursively build the kd-(sub)tree for the rows of *x*.

    The rows are sorted on column ``index``. The median row becomes this node,
    the rows before it form the left subtree, and the rows after it form the
    right subtree. Children split on the next feature,
    ``(index + 1) % n_features``.

    :param kdnode: unused; kept for backward compatibility (a fresh node is
        always created)
    :param x: samples, array-like of shape (n_samples, n_features)
    :param index: column used as the splitting feature at this level
    :param deep: depth assigned to the created node
    :param verbose: if True, print each node's value, feature and depth as it
        is created
    :return: the new ``Kd_node``, or ``None`` if ``x`` has no rows
    """
    x = np.asarray(x)
    if len(x) == 0:
        return None
    # Stable sort on the splitting column only. The former
    # np.lexsort(x[:, ::-(index + 1)].T) made column ``index`` the primary key
    # only when n_features was a multiple of index + 1 (e.g. with 3 features
    # and index 1 it sorted by column 0).
    x = x[np.argsort(x[:, index], kind="stable")]
    n_rows, n_cols = x.shape
    mid = n_rows // 2  # the median row becomes this node

    node = Kd_node()
    node.value = x[mid, :]
    node.deep = deep
    node.feature = index
    if verbose:
        print('kdnode.value:', node.value)
        print('kdnode.feature:', node.feature)
        print('kdnode.deep:', node.deep)

    next_index = (index + 1) % n_cols
    # Empty halves (e.g. the right half of a 2-row slice) come back as None.
    node.left = BuildKdTree(None, x[:mid, :], next_index, deep + 1, verbose)
    node.right = BuildKdTree(None, x[mid + 1:, :], next_index, deep + 1, verbose)
    for child in (node.left, node.right):
        if child is not None:
            child.parent = node
    return node

2. 搜索 KD 树

def search(node, x):
    """Return the stored point nearest to the query *x* (Euclidean distance).

    This is the standard kd-tree nearest-neighbour search:
    1. Descend toward the leaf on the query's side of each splitting plane.
    2. On the way back up, update the best candidate.
    3. Visit the other child only when its splitting plane lies within the
       current best distance.

    :param node: root of the tree, or ``None`` for an empty tree. Nodes must
        have ``value``, ``left`` and ``right`` attributes, and normally a
        ``feature`` attribute.
    :param x: query point, a sequence with one entry per feature
    :return: the ``value`` of the nearest node, or ``None`` for an empty tree

    For backward compatibility, the result and its distance are also stored
    in the module-level globals ``nearestPoint`` and ``nearestValue``.
    ``nearestValue`` is 0 when the tree is empty.
    """
    global nearestPoint, nearestValue
    best_point = None
    best_dist = 0

    def travel(cur, depth=0):
        nonlocal best_point, best_dist
        if cur is None:
            return
        # Split on the axis the node was actually built with. Train starts at
        # the highest-variance feature, so ``depth % len(x)`` disagrees with
        # the tree whenever that feature is not 0, and pruning then skips the
        # wrong subtrees. Fall back to it only for nodes without ``feature``.
        axis = getattr(cur, "feature", None)
        if axis is None:
            axis = depth % len(x)
        if x[axis] < cur.value[axis]:
            near, far = cur.left, cur.right
        else:
            near, far = cur.right, cur.left
        travel(near, depth + 1)
        d = dist(x, cur.value)
        if best_point is None or d < best_dist:
            best_point, best_dist = cur.value, d
        # The far side can only contain a closer point if the splitting plane
        # is within the current best distance of the query.
        if abs(x[axis] - cur.value[axis]) <= best_dist:
            travel(far, depth + 1)

    travel(node)
    nearestPoint, nearestValue = best_point, best_dist
    return best_point


def dist(x1, x2):
    """Return the Euclidean distance between points *x1* and *x2*."""
    diff = np.asarray(x1) - np.asarray(x2)
    return (diff ** 2).sum() ** 0.5

# Usage example: build a kd-tree over six 2-D points and find the stored
# point nearest to a query (for this data that is [2, 3]).
# NOTE(review): assumes the functions above are importable as module ``kd``;
# that import is not shown here.
x = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
node = kd.Train(x)
query = [2.1, 3.1]  # not named ``input`` so the builtin is not shadowed
print(kd.search(node, query))