kd树 寻找k近邻算法 python实现

kd树和寻找k近邻算法具体流程就不赘述了,这个链接写的很好懂 https://zhuanlan.zhihu.com/p/23966698

 按照链接里的算法写了k近邻的python实现

from math import sqrt
class KDnode:
    def __init__(self, data, left, right, split):
        self.left = left
        self.right = right
        self.split = split
        self.data = data

class KDtree:
    def __init__(self,data):
        self.k = len(data[0])

        def CreatKD(split, data_set):
            if not data_set:
                return None
            data_set.sort(key=lambda x: x[split])
            flag = len(data_set)//2
            new_split = (split+1) % self.k
            return KDnode(data_set[flag], CreatKD(new_split,data_set[:flag]), CreatKD(new_split, data_set[flag+1:]), split)

        self.root = CreatKD(0, data)

def nearest(tree, point, k):
    L = []

    def dis(x,p):  #x是当前节点,p是目标节点
        if len(L)<k:
            d = sqrt(sum((x1-x2)**2 for x1, x2 in zip(x, p)))
            L.append([x,d])
            return
        else:
            d = sqrt(sum((x1 - x2) ** 2 for x1, x2 in zip(x, p)))
            L.sort(key=lambda a: a[1])
            if(L[-1][1] > d):
                L.pop()
                L.append([x, d])
            return

    def travel(kd_node):
        if kd_node is None:
            return

        s = kd_node.split
        if kd_node.data[s] > point[s]:
            nearnode = kd_node.left
            furthnode = kd_node.right
        else:
            nearnode = kd_node.right
            furthnode = kd_node.left
        travel(nearnode)
        dis(kd_node.data, point)
        dis1 = abs(kd_node.data[s] - point[s])
        dis2 = sqrt(sum((x1 - x2) ** 2 for x1, x2 in zip(kd_node.data, point)))  #最长距离
        if len(L) < k or dis1 < dis2:
            travel(furthnode)
        else:
            return
    travel(tree.root)
    return L

# data = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
# kd = KDtree(data)
# ret = nearest(kd, [3,4.5], 2)
# print(ret)

from time import clock
from random import random

# 产生一个k维随机向量,每维分量值在0~1之间
def random_point(k):
    return [random() for _ in range(k)]

# 产生n个k维随机向量
def random_points(k, n):
    return [random_point(k) for _ in range(n)]
N = 400000
t0 = clock()
kd2 = KDtree(random_points(3, N))  # 构建包含四十万个3维空间样本点的kd树
t1= clock()
print("time: ", t1-t0, "s")
ret2 = nearest(kd2, [0.1,0.5,0.8], 2)      # 四十万个样本点中寻找离目标最近的点
t2 = clock()
print("time: ", t2-t1, "s")
print(ret2)

运行结果如下: 

发布了61 篇原创文章 · 获赞 15 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/surserrr/article/details/98884472