K-D Tree是K临近算法中的一种
k-d 是 k-dimensional 的缩写,也就是k维树,换句话说就是说这个树是维护一个k维元素的树,这个树上的节点有k个分量。一个k维的二叉搜索树就是k-d tree了。
先了解一下K-D Tree的概念
例子(网图)
K == 3的K-D Tree模型
首先来看下树的组织原则。将每一个元组按0排序(第一项序号为0,第二项序号为1,第三项序号为2),在树的第n层,第 n%3 项被用粗体显示,而这些被粗体显示的树就是作为二叉搜索树的key值,比如,根节点的左子树中的每一个节点的第一个项均小于根节点的的第一项,右子树的节点中第一项均大于根节点的第一项,子树依次类推。
对于这样的一棵树,对其进行搜索节点会非常容易,给定一个元组,首先和根节点比较第一项,小于往左,大于往右,第二层比较第二项,依次类推。
如此,我们知道了建树的原则,但是这样建树是为了什么呢?
看到一个一维K == 1的K-D Tree的模型,我们将其投影到坐标轴上来看。
这样,就是一个二叉搜索树的模型了,我们想找一个点值,就可以利用二叉搜索树BST的原则来进行搜索了。
于是,我们现在要将这样的一维拓展到二维,那么,原来的用点来分割,到了二维之后就是用线来将其分割开了。
利用垂直于坐标轴的直线,将图中点分割
第一次:黄色点
第二次:红色点
第三次:绿色点
第四次:蓝色点
结束算法。
这样以来,我们将原图分割成了多个小块的样式,但是却能保证这个K == 2的二维空间能够利用BST的二分搜索树进行检索了。
同理,在K==3的时候,我们再看到之前给出的样例,就能明白为什么是这样进行分割了。
建树
建树的思想是基于二分搜索树的,这里的做法类似于CDQ分治。
问题来了,既然每个元素有k个分量,我们以哪一个分量作为标准去划分左右子树呢?
对每个分量依次轮流划分当然是可以的。
更好的是每次对跨度最大的那一个分量进行划分。
怎么理解?
我们假设现在K==2,有n个2维的点(x,y),我们要将其建成一个k-d tree。
如果每个分量轮流作为度量那就是以下情况:
现在这些点都在一个矩形中,我们先将这个正方形竖着切一刀,这样相当于对x划分了一下,左边的点都进左子树,右边的点都进右子树。
然后对于左边的矩形和右边的矩形我们再分别横着切一刀,将上下两边的点分别进左右子树。
以此类推……
而每次选跨度大的分量划分则是每次看所要划分的矩形究竟是水平的长度长还是垂直的长度长。那么大家就可以脑补了。
现在另一个问题来了,确定了我们每次划分的维度之后,我们究竟以哪一个结点作为划分呢?(相当于究竟在矩形什么位置切一刀?)
我们当然是希望左右子树的节点数越接近越好,所以答案就显而易见了,我们只要在中位数的位置划分左右子树就行了。
void build(int rt, int l, int r)
{
if(l > r) return;
op = 0; key[rt] = 0;
for(int i=0; i<K; i++)
{
double ave = 0.;
var[i] = 0.;
for(int j=l; j<=r; j++) ave += a[j].d[i];
ave /= (r - l + 1.);
for(int j=l; j<=r; j++) var[i] += (ave - a[j].d[i]) * (ave - a[j].d[i]);
var[i] /= (r - l + 1.);
if(var[i] > var[key[rt]])
{
key[rt] = i;
op = i;
}
}
int mid = HalF;
nth_element(a + l, a + mid, a + r + 1, cmp);
tree[rt] = a[mid];
build(rt << 1, l, mid - 1); build(Rson);
}
因此,我的建树的代码就是这样的。
其中用到的STL:
C++ nth_element(STL nth_element)排序算法
应用的范围由它的第一个和第三个参数指定。第二个参数(作为中间参数)是一个指向第 n 个元素的迭代器。如果这个范围内的元素是完全有序的,nth_dement() 的执行会导致第 n 个元素被放置在适当的位置。这个范围内,在第 n 个元素之前的元素都小于第 n 个元素,而且它后面的每个元素都会比它大。算法默认用 “ < ” 运算符来生成这个结果。
这是一个例子(元素集:22, 7, 93, 45, 19, 56, 88, 12, 8, 7, 15, 10)
查询
k-d tree常常用来处理的问题是查询k维空间里一个点最邻近的点。
做法就是从根节点出发,判断每个点能否更新当前的最近距离。然后判断所查询的点在左子树还是右子树,进入相应的子树继续更新最近距离。
但是有个问题,查询的点不在的那个区间也有可能有最近的点在,所以我们需要在回溯的时候判断需不需要递归到另一个区间内查询。
判断的条件就是当前维护的最近距离比目标点到分割两个区间的超平面的距离大,或者换句话说就是以目标点为球心,当前最近距离为半径的超球面与被分割出的另一部分相交。
查询操作最好情况是O(logn) 最坏情况是O(n)
void query(int rt, int l, int r)
{
if(l > r) return;
int mid = HalF;
ll dist = dis(q, tree[rt]);
if(q.id ^ tree[rt].id)
{
if(dist < ans)
{
ans = dist;
point = tree[rt];
}
}
int k_key = key[rt];
ll ra = (q.d[k_key] - tree[rt].d[k_key]) * (q.d[k_key] - tree[rt].d[k_key]);
if(q.d[k_key] < tree[rt].d[k_key])
{
query(rt << 1, l, mid - 1);
if(ra <= ans) query(Rson);
}
else
{
query(Rson);
if(ra <= ans) query(rt << 1, l, mid - 1);
}
}
模板题 求二维平面每个点的最近点的欧式距离的平方
#include <iostream>
#include <cstdio>
#include <cmath>
#include <string>
#include <cstring>
#include <algorithm>
#include <limits>
#include <vector>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <bitset>
//#include <unordered_map>
//#include <unordered_set>
#define lowbit(x) ( x&(-x) )
#define pi 3.141592653589793
#define e 2.718281828459045
#define INF 0x3f3f3f3f3f3f3f3f
#define eps 1e-8
#define HalF (l + r)>>1
#define lsn rt<<1
#define rsn rt<<1|1
#define Lson lsn, l, mid
#define Rson rsn, mid+1, r
#define QL Lson, ql, qr
#define QR Rson, ql, qr
#define myself rt, l, r
#define MP(a, b) make_pair(a, b)
#define MAX_3(a, b, c) max(a, max(b, c))
#define Rabc(x) x > 0 ? x : -x
using namespace std;
typedef unsigned long long ull;
typedef unsigned int uit;
typedef long long ll;
const int K = 2, maxN = 1e5 + 7;
int N;
struct node
{
ll d[K]; int id;
node(ll a=0, ll b=0, int c=0):d{a, b}, id(c) {}
void In() { scanf("%lld%lld", &d[0], &d[1]); }
} tree[maxN << 2], a[maxN], b[maxN];
int op;
inline bool cmp(node e1, node e2) { return e1.d[op] < e2.d[op]; }
node q, point;
ll ans;
int key[maxN << 2];
double var[maxN << 2];
ll dis(node a, node b)
{
ll sum = 0;
for(int i=0; i<K; i++) sum += (a.d[i] - b.d[i]) * (a.d[i] - b.d[i]);
return sum;
}
void build(int rt, int l, int r)
{
if(l > r) return;
op = 0; key[rt] = 0;
for(int i=0; i<K; i++)
{
double ave = 0.;
var[i] = 0.;
for(int j=l; j<=r; j++) ave += a[j].d[i];
ave /= (r - l + 1.);
for(int j=l; j<=r; j++) var[i] += (ave - a[j].d[i]) * (ave - a[j].d[i]);
var[i] /= (r - l + 1.);
if(var[i] > var[key[rt]])
{
key[rt] = i;
op = i;
}
}
int mid = HalF;
nth_element(a + l, a + mid, a + r + 1, cmp);
tree[rt] = a[mid];
build(rt << 1, l, mid - 1); build(Rson);
}
void query(int rt, int l, int r)
{
if(l > r) return;
int mid = HalF;
ll dist = dis(q, tree[rt]);
if(q.id ^ tree[rt].id)
{
if(dist < ans)
{
ans = dist;
point = tree[rt];
}
}
int k_key = key[rt];
ll ra = (q.d[k_key] - tree[rt].d[k_key]) * (q.d[k_key] - tree[rt].d[k_key]);
if(q.d[k_key] < tree[rt].d[k_key])
{
query(rt << 1, l, mid - 1);
if(ra <= ans) query(Rson);
}
else
{
query(Rson);
if(ra <= ans) query(rt << 1, l, mid - 1);
}
}
int main()
{
int T; scanf("%d", &T);
while(T--)
{
scanf("%d", &N);
for(int i=1; i<=N; i++)
{
a[i].In();
a[i].id = i;
b[i] = a[i];
}
build(1, 1, N);
for(int i=1; i<=N; i++)
{
q = b[i];
ans = INF;
query(1, 1, N);
printf("%lld\n", ans);
}
}
return 0;
}