kd树的创建和求最近邻

 1 import numpy as np
 2 arr = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
 3 arr.shape
 4 
 5 class KDTree():
 6     def __init__(self):
 7         self.value = None
 8         self.left = None
 9         self.right = None
10         self.axis = None
11 
def create(arr, k, h=0):
    """Recursively build a k-d tree from a collection of points.

    Parameters
    ----------
    arr : array-like of shape (n_points, n_features)
        Points to store, one per row. Lists are accepted and converted with
        ``np.asarray``.
    k : int
        Number of feature dimensions; the split axis cycles through 0 .. k-1.
    h : int, optional
        Depth of the node being built (0 for the root).

    Returns
    -------
    KDTree or None
        Root of the (sub)tree, or None when ``arr`` holds no points.
    """
    arr = np.asarray(arr)
    if arr.shape[0] == 0:
        return None
    axis = h % k
    # Stable argsort keeps rows with equal keys in input order, matching the
    # Python sorted() it replaces, but the sort runs inside NumPy.
    arr = arr[np.argsort(arr[:, axis], kind="stable")]
    # The median row becomes this node. A single point needs no special case:
    # i == 0 and both child slices are empty, so both children become None.
    i = arr.shape[0] // 2
    tree = KDTree()
    tree.value = arr[i]
    tree.left = create(arr[:i], k, h + 1)
    tree.right = create(arr[i + 1:], k, h + 1)
    tree.axis = axis
    return tree
33 
# Build the tree over all sample points; the dimensionality comes from the data.
# (An earlier `k = KDTree()` was dead: it was overwritten immediately.)
k = create(arr, arr.shape[1])
37 
def preOrder(k):
    """Print every node's point in pre-order: node, then left subtree, then right.

    ``k`` is the root node. An empty tree (None) prints nothing instead of
    raising AttributeError.
    """
    if k is None:
        return
    # The printed label means "current node:".
    print('当前节点:' + str(k.value))
    preOrder(k.left)
    preOrder(k.right)
45 
# Dump the freshly built tree, root first.
preOrder(k)
47 
def dis(a, b):
    """Return the Euclidean (L2) distance between points ``a`` and ``b``."""
    delta = a - b
    return np.linalg.norm(delta)
def search(kd, goal, k, h=0):
    """Return the Euclidean distance from ``goal`` to its nearest point in the tree.

    Parameters
    ----------
    kd : KDTree or None
        Root of the (sub)tree to search. None (an empty tree) yields ``inf``.
    goal : array-like
        Query point with at least ``k`` coordinates (e.g. a NumPy array).
    k : int
        Number of feature dimensions the tree was built with.
    h : int, optional
        Depth of ``kd`` in the tree (0 for the root); selects the split axis.

    Returns
    -------
    float
        Smallest Euclidean distance between ``goal`` and any stored point.
    """
    if kd is None:
        return float('inf')
    axis = h % k
    diff = goal[axis] - kd.value[axis]
    # Descend first into the side of the splitting plane that contains goal.
    near, far = (kd.left, kd.right) if diff < 0 else (kd.right, kd.left)
    best = min(dis(goal, kd.value), search(near, goal, k, h + 1))
    # The far side can only hold a closer point if the ball of radius `best`
    # around goal crosses the splitting plane. abs() is required: with a signed
    # difference, pruning never fires when goal lies to the right of the plane.
    if abs(diff) < best:
        best = min(best, search(far, goal, k, h + 1))
    return best
75 
# Nearest-neighbour distance from (9, 6) to the sample points. A bare call's
# result is discarded outside a notebook, so print it.
print(search(k, np.array([9, 6]), 2))

猜你喜欢

转载自：www.cnblogs.com/liuwenhan/p/11723354.html