KD-Tree 初学(模板+HDU 4347)

        KD_Tree的话,一年半以前,那时候翻我们学校给的模板翻到看了,上面写的"by edward_mj"。这个ID浙大的人应该很熟悉,13-14年连续两年带队打进WF。这里我可以吹一下,edward_mj可是和我同一个高中毕业的哟~当时也问了一下学长,学长说KD树是他写的最溜的一个数据结构之一。

        直到这几天我才有时间来学学KD树。所谓KD树,顾名思义k-dimension tree,能够存储k维数据的一棵树。这里这个树其实是指BST。朴素的BST,像Treap、Splay这些,节点都是存储一维的信息,但是KD树就是它们的拓展,可以存储多维数据。一维的时候很容易做到,直接把数据大小当作key,小的放在左边大的放在右边即可。其实拓展到高维的时候也是类似的,只不过要考虑多个维度,最显而易见的想法就是对于每一层,我用不同的维度的数值作为key来进行比较。例如说,对于一个二维数据,我第一层用第一维,第二层用第二维,第三层第一维,第四层第二维……这样,一棵树可以比较均匀的建立。下图给出了一个建立好的KD树的案例。

                         

        可以看到,这个案例中总共有6个点,画在坐标轴上就如图1所示,然后我们对图1的区域进行划分。按照顺序,显示竖着切(根据横坐标x)再是横着切(根据纵坐标y)。(7,2)作为根节点,竖着切一刀,所有横坐标小于7的在它的左边,大于等于的在它的右边。接着,下一层换作另外一个坐标继续分割,直到所有的点被划分完毕。

        KD树的建立就是这样,与普通的BST相差不大。但他的主要功能却与BST不同,一般来说,KD树用来求与已知点距离最近的m个点是哪几个,也即KNN问题k邻近查询问题。给你一个点,第一件要做的事情显然就是往他所在的区域进行查找,所以对于给定点按照其数值在树上遍历,一直找到一个最小的区域再开始回溯。这里我还是用图结合上上面那个案例来说明。


        对于给定的查询点(2,4.5),我们首先要确定他所在的最小的区域,显然就是左图左上角那一个矩形的区域。然后,目前为止我们能够直到的距离查询点最近的点是点(4,7),于是我们以这两点之间的距离为半径,查询点为圆心做一个圆(如果高维那就做超球体),判断这个这个最近点的父亲是否在圆内或者圆上。如果在,那么说明这个它的兄弟所在的区域里面可能会有比它距离查询点更近的点,所以要对它的兄弟也进行搜索。这样,把所有该搜索的地方都搜索完毕后,剩下的点就是最邻近的点。如果是求K邻近,那么只需要弄一个优先队列,里面始终保证有且仅有K个点即可。

        KD树的话,初学就暂时学了这么多。代码实现,见下面我给的模板题的代码吧。模板题是 HDU 4347 。题意就是一个m邻近问题。代码如下:

#include<bits/stdc++.h>
#define sq(x) (x)*(x)
#define N (55555)

using namespace std;

int idx,k,n,m,q;

struct Node
{
    int x[5];
    bool operator < (const Node &u) const
    {
        return x[idx] < u.x[idx];
    }
} P[N];

typedef pair<double,Node> PDN;
priority_queue<PDN> que;

struct KD_Tree
{
    int sz[N<<2]; Node p[N<<2];

    void build(int i,int l,int r,int dep)
    {
        if (l>r) return;
        int mid=(l+r)>>1;
        idx=dep%k;sz[i]=r-l;
        sz[i<<1]=sz[i<<1|1]=-1;
        nth_element(P+l,P+mid,P+r+1);
        p[i]=P[mid];
        build(i<<1,l,mid-1,dep+1);
        build(i<<1|1,mid+1,r,dep+1);
    }

    void query(int i,int m,int dep,Node a)
    {
        if (sz[i]==-1) return;
        PDN tmp=PDN(0,p[i]);
        for(int j=0;j<k;j++)
            tmp.first+=sq(tmp.second.x[j]-a.x[j]);
        int lc=i<<1,rc=i<<1|1,dim=dep%k,flag=0;
        if (a.x[dim]>=p[i].x[dim]) swap(lc,rc);
        if (~sz[lc]) query(lc,m,dep+1,a);
        if (que.size()<m) que.push(tmp),flag=1;
        else
        {
            if (tmp.first<que.top().first) que.pop(),que.push(tmp);
            if (sq(a.x[dim]-p[i].x[dim])<que.top().first) flag=1;
        }
        if (~sz[rc]&&flag) query(rc,m,dep+1,a);
    }
} KDT;

int main()
{
    while(~scanf("%d%d",&n,&k))
    {
        for(int i=0;i<n;i++)
            for(int j=0;j<k;j++)
                scanf("%d",&P[i].x[j]);
        KDT.build(1,0,n-1,0);
        scanf("%d",&q);
        while(q--)
        {
            Node now;
            for(int i=0;i<k;i++)
                scanf("%d",&now.x[i]);
            scanf("%d",&m); int t=0;
            KDT.query(1,m,0,now); Node pp[21];
            for(;!que.empty();que.pop())
                pp[++t]=que.top().second;
            printf("the closest %d points are:\n",t);
            for(int i=m;i>0;i--)
            {
                printf("%d",pp[i].x[0]);
                for(int j=1;j<k;j++)
                    printf(" %d",pp[i].x[j]);
                puts("");
            }
        }
    }

    return 0;

}

猜你喜欢

转载自blog.csdn.net/u013534123/article/details/80952174