KDTree 解释与实现

前言:《统计学习方法》第三章 K 近邻

KDTree 解释

一种数据结构,能快速搜索最近点

KDTree 实现

文字叙述

  • 选择一个维度(x,y,z ......)
  • 选出这些点这个维度值的中位数
  • 将数据按中位数分为两部分
  • 对这两部分数据同样执行上述操作,直到数据点的数目为 1

图片叙述

15548795-514587725ccc0f47.png
步骤
15548795-10acc5c79d490dac.png
步骤
15548795-3e17fd572bc6911b.png
步骤
15548795-6f46da2a3f887323.png
步骤
15548795-1550767e12caa5cc.png
步骤

15548795-177ea583329ee85b.png
最后

更加详细的内容来自这里

代码

class KdNode:
    def __init__(self, axis, point, left, right):
        self.axis = axis
        self.point = point
        self.left = left
        self.right = right


class KdTree(object):
    def __init__(self, data):
        """
        data is like [[1, 2], [3, 4], [5, 6]]
        """
        # 如果点是 二维的 K 值是 2
        # 如果点是 三维的 K 值是 3
        k = len(data[0])
        self.node_num = len(data)
        def create_node(axis, data_set):
            if not data_set:
                return None
            # 当前节点
            data_set.sort(key=lambda x: x[axis])
            point_pos = len(data_set) // 2
            point_media = data_set[point_pos]
            next_axis = (axis + 1) % k
            return KdNode(axis, point_media, 
                        create_node(next_axis, data_set[0:point_pos]),
                        create_node(next_axis, data_set[point_pos+1:]))
        
        self.root = create_node(0, data)

建树的时间复杂度分析

我用的是最简单的策略,每一次寻找中位数的时候,都用一次快排。也就是,每一个节点都会进行一次快排

  • 第一层 O(nlogn)
  • 第二层 O(nlog(n/2))
  • 第三层 O(nlog(n/4))
  • ... ...

树的高度是 logn
所以最后时间复杂度是 O(n(logn)^2)

猜你喜欢

转载自blog.csdn.net/weixin_34185560/article/details/86808310