KD树学习笔记(只适合OIer)

先思考一个问题:

  • 在K维空间里面有许多的点,对于某些给定的点,我们需要找到和它最近的m个点。
  • 这里的距离指的是欧几里得距离:
  • D(p,q)=D(q,p)=sqrt((q1-p1)^2+(q2-p2)^2+(q3-p3)^2+...+ (qn-pn)^2),请你帮忙解决一下。

       输入:

  • 点数n(1≤n≤50000)和维度数k(1≤k≤5)。
  • 接下来的n行,每行k个整数,代表一个点的坐标。
  •  接下来一个正整数:给定的询问数量t(1≤t≤10000)
  • 下面2*t行:
  • 第一行k个整数,表示要查询的点的坐标
  • 第二行一个整数m,表示查询最近的m个点(1≤m≤10)
  •  所有坐标的绝对值不超过10000。
  • 有多组数据!

  输出:

  • 对于每个询问,输出m+1行:
  • 第一行:"the closest m points are:" m为查询中的m
  •  接下来m行每行代表一个点,按照从近到远排序。
  •  保证方案唯一,下面这种情况不会出现:
  • 2 2
  • 1 1
  • 3 3
  • 1
  • 2 2
  • 1

我们知道在二维的情况下我们可以用树状数组来解决(乱搞)。但此时题目中给出了一个会变化的维度,再用树状数组就会提高大量的思维难度(反正我是想象不出5维空间的),

此时我们就需要一种对应多维度的数据结构——KD树。

  1.  KD树的定义:

  Kd-树是K-dimension tree的缩写,是对数据点在k维空间(如二维(x,y),三维(x,y,z),k维(x1,y,z..))中划分的一种数据结构,主要应用于多维空间关键数据的搜索(如:范围搜索和最近邻搜索)。本质上说,Kd-树就是一种平衡二叉树

     首先必须搞清楚的是,k-d树是一种空间划分树,说白了,就是把整个空间划分为特定的几个部分,然后在特定空间的部分内进行相关搜索操作。想像一个三维(多维有点为难我的想象力了)空间,kd树按照一定的划分规则把这个三维空间划分了多个空间,如下图:

 更加易懂的说法是KD树实际上就是多关键字搜索(我蒟蒻只需要知道这个就够了)。

 2. KD树的构建

KD树与线段树的构建相似,需要动态递归建立

void build(int &k,int l,int r,int dir)
{
    int mid=(l+r)>>1;
    k=mid;D=dir;
    nth_element(a+l,a+mid,a+r+1,cmp);
    for(int i=0;i<K;i++)
        a[k].mi[i]=a[k].mx[i]=a[k].d[i];
    if(l<mid)build(a[k].l,l,mid-1,(dir+1)%K);
    if(r>mid)build(a[k].r,mid+1,r,(dir+1)%K);
    pushup(k);
}    

 

  3.KD树的插入

  虽然此题不用插入但我们还是要学的啊

void insert(int k,int dir)
{
    if (q[dir]<a[k].d[dir])
    {
        if (a[k].l) insert(a[k].l,(dir+1)%d);
        else
        {
            a[k].l=++n;
            for(int i=0;i<K;i++)
                a[n].mi[i]=a[n].mx[i]=a[n].d[i]=q[i];
        }
    }
    else
    {
        if (a[k].r) insert(a[k].r,(dir+1)%d);
        else
        {
            a[k].r=++n;
              for(int i=0;i<K;i++)
                a[n].mi[i]=a[n].mx[i]=a[n].d[i]=q[i];
        }
    }
    pushup(k);//同时向上维护
}       

  虽然一棵刚建好的KD树深度是O(log)的。但随便乱插会对时间有巨大的负担很容易TLE。所以我们可以用替罪羊树优化……因为博主太弱还不会(QAQ)所以请同学们自己去学习吧(学会了记得回来给我讲讲啊)!

  4.KD树的查询

  KD树的关键。还记得我们维护的mi[]和mx[],现在我们要用它来做估计了。我们都知道估计可以省下大量的计算,所以这也是KD树独特的地方。但我们的答案不能是估计啊!所以精确的也不能少(QAQ)

 

 
 
long long Guess(int k) //估算与k点的距离值
{ 
    long long i,s=0;
    for(i=0;i<K;i++)
    {
        if(q[i]<a[k].mi[i])s+=(long long)(q[i]-a[k].mi[i])*(q[i]-a[k].mi[i]);
        if(q[i]>a[k].mx[i])s+=(long long)(q[i]-a[k].mx[i])*(q[i]-a[k].mx[i]);
    }
    return s;
}
 
  
 
 
long long Dis(int k) //求查询点与k点的距离值
{ 
    long long i,ans=0;
    for(i=0;i<K;i++)ans+=(long long)(q[i]-a[k].d[i])*(q[i]-a[k].d[i]);
    return ans;
}
 
 

  

 
  
void Query(int x) 
{ 
    if(!x)return; 
    long long dis=Dis(x),dl=Guess(a[x].l),dr=Guess(a[x].r); 
    if(dis<Q.top().first)//为本题需要而建的大根堆
    {
        Q.pop();
        Q.push(make_pair(dis,x));
    }
     if(dl<dr)
    { 
         if(dl<Q.top().first)Query(a[x].l); 
         if(dr<Q.top().first)Query(a[x].r); 
     } 
    else
    { 
        if(dr<Q.top().first)Query(a[x].r);
          if(dl<Q.top().first)Query(a[x].l);
    } 
}
 
   
  
 
 

  

  

 

 

 

原题代码:

#include<bits/stdc++.h>
#define INF 0x3f3f3f3
using namespace std;
typedef pair<long long,int>pii; 
priority_queue<pii>Q; 
struct data{int d[6],mx[6],mi[6],l,r;}a[100005<<1]; 
int q[6],i,j,k,m,n,rt,D,K,t; 
bool cmp(data x,data y){return x.d[D]<y.d[D];} 
void pushup(int x)
{ 
    int i,ls=a[x].l,rs=a[x].r; 
    for(i=0;i<K;i++) 
    {
    if(ls) 
        { 
            a[x].mx[i]=max(a[x].mx[i],a[ls].mx[i]); 
            a[x].mi[i]=min(a[x].mi[i],a[ls].mi[i]); 
        }
    if(rs) 
        { 
            a[x].mx[i]=max(a[x].mx[i],a[rs].mx[i]); 
            a[x].mi[i]=min(a[x].mi[i],a[rs].mi[i]); 
        } 
    } 
}
void build(int &k,int l,int r,int dir)
{
    int mid=(l+r)>>1;
    k=mid;D=dir;
    nth_element(a+l,a+mid,a+r+1,cmp);
    for(int i=0;i<K;i++)
        a[k].mi[i]=a[k].mx[i]=a[k].d[i];
    if(l<mid)build(a[k].l,l,mid-1,(dir+1)%K);
    if(r>mid)build(a[k].r,mid+1,r,(dir+1)%K);
    pushup(k);
}    
long long Guess(int x) //估算与x点的距离值
{ 
    long long i,s=0;
    for(i=0;i<K;i++)
    {
        if(q[i]<a[x].mi[i])s+=(long long)(q[i]-a[x].mi[i])*(q[i]-a[x].mi[i]);
        if(q[i]>a[x].mx[i])s+=(long long)(q[i]-a[x].mx[i])*(q[i]-a[x].mx[i]);
    }
    return s;
}
long long Dis(int x) //求查询点与x点的距离值
{ 
    long long i,ans=0;
    for(i=0;i<K;i++)ans+=(long long)(q[i]-a[x].d[i])*(q[i]-a[x].d[i]);
    return ans;
}
void Query(int x) 
{ 
    if(!x)return; 
    long long dis=Dis(x),dl=Guess(a[x].l),dr=Guess(a[x].r); 
    if(dis<Q.top().first)
    {
        Q.pop();
        Q.push(make_pair(dis,x));
    }
     if(dl<dr)
    { 
         if(dl<Q.top().first)Query(a[x].l); 
         if(dr<Q.top().first)Query(a[x].r); 
     } 
    else
    { 
        if(dr<Q.top().first)Query(a[x].r);
          if(dl<Q.top().first)Query(a[x].l);
    } 
}
void print() //从小到大输出m个点
{ 
    int i,x;
    while(!Q.empty())
    {
        x=Q.top().second;Q.pop();
        print();
        for(i=0;i<K;i++)printf("%d ",a[x].d[i]);
        printf("\n");
    }
}
int main() 
{ 
    while(scanf("%d%d",&n,&K)!=EOF) 
    { 
        memset(a,0,sizeof(a)); 
        while(!Q.empty())Q.pop();//清空堆
        for(i=1;i<=n;i++) //读入n个点的坐标 
            for(j=0;j<K;j++)
                scanf("%d",&a[i].d[j]); 
        build(rt,1,n,0);
        scanf("%d",&t); //建立KD树 scanf("%d",&t); 
        for(i=1;i<=t;i++) //t组询问 
        { 
            for(j=0;j<K;j++)scanf("%d",&q[j]);//读入查询点的坐标 
            scanf("%d",&m); 
            for(j=1;j<=m;j++)Q.push(make_pair(INF,0));//把k个INF加入大根堆 
            Query(rt); 
            printf("the closest %d points are:\n",m); 
            print(); 
        } 
    }
    return 0; 
}

 

  

猜你喜欢

转载自www.cnblogs.com/wyb-----520/p/10162589.html