KD树的python实现

结点类型

class Kd_node:
    value = [] #节点值
    deep = None #节点深度
    feature = None #划分标志
    left = None #左子树
    right = None  # 右子树
    parent = None #父节点

1.建立kd树

def Train(x):
    """
    训练模型,输入x,y来训练
    :param x: 训练集x
    """
    #1.找到划分特征
    kd_var = np.var(x,axis=0)             #计算特征值
    max_feature_index = np.argmax(kd_var) #选出最大值的索引
    node = Kd_node()
    #2.递归分割数组
    node = BuildKdTree(node,x,max_feature_index,0)
    return node
 def BuildKdTree(kdnode,x,index,deep):
    """
    递归建树
    :param node:
    :param x:
    :deep : 深度
    :return:
    """
    #1对数据进行按照第几列排序
    x = np.array(x)
    x = x[np.lexsort(x[:, ::-(index + 1)].T)]
    #按照中间值进行节点赋值
    x_number = np.size(x,0) #计算有多少行
    x_lie = np.size(x, 1)  # 计算有多少行
    x_midd = x_number//2    #计算中间的值
    kdnode = Kd_node()
    kdnode.value = x[x_midd,:]
    kdnode.deep = deep
    kdnode.feature = index
    print('kdnode.value:',kdnode.value)
    print('kdnode.feature:', kdnode.feature)
    print('kdnode.deep:', kdnode.deep)
    #数据划分
    x_left = x[0:x_midd,:]
    x_right = x[x_midd+1:,:]
    if x_number == 1: #一个元素直接赋值即可
        return kdnode
    elif x_number == 2:
        kdnode.left = BuildKdTree(kdnode.left, x_left, (index + 1) % x_lie, deep + 1)
        kdnode.left.parent = kdnode
        return kdnode
    else:
        kdnode.left = BuildKdTree(kdnode.left, x_left, (index + 1) % x_lie, deep + 1)
        kdnode.left.parent = kdnode
        kdnode.right = BuildKdTree(kdnode.right, x_right, (index + 1) % x_lie, deep + 1)
        kdnode.right.parent = kdnode
        return kdnode

3.搜索kd树

def search(node, x):  
    global nearestPoint
    global nearestValue
    nearestPoint = None  
    nearestValue = 0  
    def travel(node, depth=0):  
        global nearestPoint
        global nearestValue
        if node != None:  
            n = len(x)  
            axis = depth % n  
            if x[axis] < node.value[axis]:  
                travel(node.left, depth + 1)
            else:
                travel(node.right, depth + 1)
            distNodeAndX = dist(x, node.value)  
            if (nearestPoint is None): 
                nearestPoint = node.value
                nearestValue = distNodeAndX
            elif (nearestValue > distNodeAndX):
                nearestPoint = node.value
                nearestValue = distNodeAndX
            if (abs(x[axis] - node.value[axis]) <= nearestValue):  
                if x[axis] < node.value[axis]:
                    travel(node.right, depth + 1)
                else:
                    travel(node.left, depth + 1)
    travel(node)
    return nearestPoint

def dist(x1, x2): 
    return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

调用代码

x = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
node = kd.Train(x)
input = [2.1,3.1]
print(kd.search(node,input))
发布了30 篇原创文章 · 获赞 62 · 访问量 3082

猜你喜欢

转载自blog.csdn.net/weixin_43981664/article/details/104284292