bryce1010专题——KD-Tree

版权声明:时间是有限的,知识是无限的,那就需要在有限的时间里最大化的获取知识。 https://blog.csdn.net/Fire_to_cheat_/article/details/83421046

bryce1010专题——KD-Tree

【KD-Tree模板题】
HDU-4347 The Closest M Points
http://acm.hdu.edu.cn/showproblem.php?pid=4347
【题意】
求一个点的最近M个点。
(1<=n<=50000,1<=t<=10000)

【思路】
KD-Tree

#include <math.h>
#include <stdio.h>
#include <string.h>

#include <algorithm>
#include <iostream>
#include <queue>
#include <vector>
 
using namespace std;
 
// Capacity of the input point array (the problem statement limits n to 50000).
#define N 50005

// Heap-style child indices of the tree node whose index is held in a variable
// named `rt` at the point of use. Fully parenthesised so that the expansion
// stays a single operand inside larger expressions (e.g. `rson * 2`).
#define lson (rt << 1)
#define rson (rt << 1 | 1)
// (squared distance, point) entry stored in the result heap.
#define Pair pair<double, Node>
// Square of x (despite the name, this is not a square root). The outer
// parentheses keep `a / Sqrt2(b)` from expanding to `a / (b) * (b)`.
// Note: the argument is evaluated twice, so it must be free of side effects.
#define Sqrt2(x) ((x) * (x))
 
// n:   number of input points
// k:   number of coordinates per point
// idx: coordinate index that Node::operator< currently compares on
int n, k, idx;

// A point with up to 5 integer coordinates.
struct Node
{
    int feature[5];     // coordinate values of the point

    // Orders two points by the single coordinate selected through the global idx.
    bool operator < (const Node &rhs) const
    {
        const int mine = this->feature[idx];
        const int theirs = rhs.feature[idx];
        return mine < theirs;
    }
}_data[N];   // _data[] holds the points read from the input
 
// Max-heap ordered by squared distance first. KDTree::Query keeps in it the
// (at most) m points found so far that are closest to the query point.
priority_queue<Pair> Q;

// K-D tree stored implicitly in arrays: node rt has children 2*rt and 2*rt+1.
class KDTree{

    public:
        // Builds the subtree rooted at node rt from _data[l..r]; dept is the depth of rt.
        void Build(int l, int r, int rt, int dept);
        // Searches the subtree rooted at node rt for the m points nearest to p.
        void Query(Node p, int m, int rt, int dept);

    private:
        Node data[4 * N];    // point stored at each tree node
        int flag[4 * N];     // 1 if the node exists, -1 if it does not
}kd;
 
// Recursively builds the subtree rooted at node rt from the points _data[l..r].
// dept is the depth of rt; the splitting coordinate cycles as dept % k.
void KDTree::Build(int l, int r, int rt, int dept)
{
    if(r < l)
        return;                          // empty range: no node is created here

    const int leftChild = lson;
    const int rightChild = rson;
    const int median = (l + r) >> 1;

    flag[rt] = 1;                        // node rt exists
    flag[leftChild] = -1;                // children are absent until built below
    flag[rightChild] = -1;

    // idx drives Node::operator<, so it has to be set before nth_element runs.
    idx = dept % k;
    nth_element(_data + l, _data + median, _data + r + 1);
    data[rt] = _data[median];

    Build(l, median - 1, leftChild, dept + 1);    // left subtree
    Build(median + 1, r, rightChild, dept + 1);   // right subtree
}
 
// Collects into the global max-heap Q the (at most) m points of the subtree
// rooted at node rt that are closest to p, measured by squared Euclidean
// distance over the first k coordinates.
//
//   p    - query point
//   m    - number of neighbours wanted; nothing is done when m <= 0
//   rt   - index of the current tree node
//   dept - depth of rt; must match the depth used by Build so that the same
//          splitting coordinate (dept % k) is selected
//
// The caller is expected to empty Q before the top-level call.
void KDTree::Query(Node p, int m, int rt, int dept)
{
    // m <= 0 would otherwise reach Q.top() on an empty heap below.
    if(m <= 0 || flag[rt] == -1) return;

    // Squared distance from the current node to p. The arithmetic is done in
    // double so that large coordinate differences cannot overflow int.
    Pair cur(0, data[rt]);
    for(int i = 0; i < k; i++)
    {
        const double diff = static_cast<double>(cur.second.feature[i]) - p.feature[i];
        cur.first += diff * diff;
    }

    const int dim = dept % k;            // same splitting coordinate as in Build
    // Signed distance from p to the splitting plane of this node.
    const double gap = static_cast<double>(p.feature[dim]) - data[rt].feature[dim];

    int x = lson;                        // x: side that contains p, visited first
    int y = rson;                        // y: the other side, visited only if needed
    if(gap >= 0)
        swap(x, y);
    if(flag[x] != -1) Query(p, m, x, dept + 1);

    // Backtracking step: offer the current node to the heap, then decide
    // whether the far side can still contain a closer point.
    bool visitFar = false;
    if(Q.size() < static_cast<size_t>(m))
    {
        Q.push(cur);
        visitFar = true;                 // heap not full yet: every point is a candidate
    }
    else
    {
        if(cur.first < Q.top().first)    // closer than the current worst candidate
        {
            Q.pop();
            Q.push(cur);
        }
        // The far side matters only if the splitting plane is nearer than the
        // current worst candidate.
        visitFar = gap * gap < Q.top().first;
    }
    if(visitFar && flag[y] != -1)
        Query(p, m, y, dept + 1);
}
 
// Writes the first k coordinates of a point on one line, separated by single
// spaces and terminated by a newline.
void Print(Node data)
{
    int i = 0;
    while(i < k)
    {
        const char separator = (i + 1 == k) ? '\n' : ' ';
        printf("%d%c", data.feature[i], separator);
        ++i;
    }
}
 
int main()
{
    while(scanf("%d%d", &n, &k)!=EOF)
    {
        for(int i = 0; i < n; i++)
            for(int j = 0; j < k; j++)
                scanf("%d", &_data[i].feature[j]);
        kd.Build(0, n - 1, 1, 0);
        int t, m;
        scanf("%d", &t);
        while(t--)
        {
            Node p;
            for(int i = 0; i < k; i++)
                scanf("%d", &p.feature[i]);
            scanf("%d", &m);
            while(!Q.empty()) Q.pop();   //事先需要清空优先队列
            kd.Query(p, m, 1, 0);
            printf("the closest %d points are:\n", m);
            Node tmp[25];
            for(int i = 0; !Q.empty(); i++)
            {
                tmp[i] = Q.top().second;
                Q.pop();
            }
            for(int i = m - 1; i >= 0; i--)
                Print(tmp[i]);
        }
    }
    return 0;
}

【KD-Tree模板】
HDU2966 In case of failure
http://acm.hdu.edu.cn/showproblem.php?pid=2966

【题意】
给n个二维点,求每个点距离其它点最近的距离。

import java.util.Arrays;
import java.util.Scanner;

/**
 * Array-backed 2-D k-d tree used to find, for every input point, the squared
 * distance to its nearest neighbour at a non-zero distance.
 *
 * The tree is implicit: for a range [l, r] of the point array the node is the
 * element at mid = (l + r) / 2, the left subtree is [l, mid - 1] and the right
 * subtree is [mid + 1, r]. The median is selected with a median-of-medians
 * (BFPTR) selection.
 */
public class Main {

    final static int SIZE = 100005;
    /** Kept for source compatibility; coordinates are integers and are compared exactly. */
    final static double EPS = 1e-10;

    /** d[i] is true when node i splits on x[0], false when it splits on x[1]. */
    private boolean[] d = null;
    /** Points, rearranged in place by Build into the implicit tree layout. */
    private Node[] p = null;
    /** Best squared distance found by the current search. */
    private long res;
    /** Index in p of the point that produced res. */
    private int index;
    /** Number of points currently in use. */
    private int size;

    /** A 2-D point (also reused by getInterval as a {min, max} pair). */
    public class Node{
        private long[] x = null;
        Node(){
            x = new long[2];
        }
    }

    Main(int size){
        d = new boolean[size];
        p = new Node[size];
        for(int i = 0; i < size; i++)
            p[i] = new Node();
    }

    /** Sets the number of points in use and clears all split flags. */
    public void setSize(int size){
        this.size = size;
        Arrays.fill(d, false);
    }

    /** Resets the state of the nearest-neighbour search. */
    public void clear(){
        res = Long.MAX_VALUE;
        index = 0;
    }

    /** Stores the reference t at position id (no copy is made). */
    public void Insert(int id, Node t){
        p[id] = t;
    }

    /** Returns the point currently stored at position id. */
    public Node get(int id){
        return p[id];
    }

    /** Exchanges a[i] and a[j]. */
    private static void swap(Node a[], int i, int j){
        Node t = a[i];
        a[i] = a[j];
        a[j] = t;
    }

    /** Stable insertion sort of a[l..r] by coordinate id, ascending. */
    public void InsertSort(Node a[], int id, int l, int r){
        for(int i = l + 1; i <= r; i++){
            Node t = a[i];
            int j = i;
            while(j > l && a[j - 1].x[id] > t.x[id]){
                a[j] = a[j - 1];
                j--;
            }
            a[j] = t;
        }
    }

    /**
     * Median-of-medians pivot selection on a[l..r] by coordinate id.
     * Each group is sorted and its median is moved to the front of the range;
     * the procedure then recurses on those medians. Elements are rearranged
     * within a[l..r] only.
     *
     * @return the chosen pivot element
     */
    public Node FindMid(Node a[], int id, int l, int r)
    {
        if(l == r) return a[l];
        int i;
        int n = 0;
        // Full groups of five; the median of group g is moved to a[l + g].
        for(i = l; i < r - 5; i += 5)
        {
            InsertSort(a, id, i, i + 4);
            n = i - l;
            swap(a, l + n / 5, i + 2);
        }

        // Remaining tail group.
        int num = r - i + 1;
        if(num > 0)
        {
            InsertSort(a, id, i, i + num - 1);
            n = i - l;
            swap(a, l + n / 5, i + num / 2);
        }
        n /= 5;                       // n is now the index of the last median relative to l
        // A single group: its median already sits at a[l]. The test must be
        // against 0; comparing the group count with the array index l would
        // cut the recursion short whenever the two happen to coincide.
        if(n == 0) return a[l];
        return FindMid(a, id, l, l + n);
    }

    /** Returns true when both coordinates of a and b are equal. */
    public boolean Equals(Node a, Node b){
        return a.x[0] == b.x[0] && a.x[1] == b.x[1];
    }

    /** Returns the first index in [l, r] whose point equals num, or -1 if there is none. */
    public int FindId(Node a[], int l, int r, Node num)
    {
        for(int i = l; i <= r; i++)
            if(Equals(a[i], num))
                return i;
        return -1;
    }

    /**
     * Partitions a[l..r] by coordinate id around the element at index p.
     *
     * @return the final index of the pivot; smaller-or-equal keys are to its
     *         left and greater-or-equal keys to its right
     */
    public int Partion(Node a[], int id, int l, int r, int p)
    {
        swap(a, p, l);

        int i = l;
        int j = r;
        Node pivot = a[l];
        while(i < j)
        {
            while(a[j].x[id] >= pivot.x[id] && i < j)
                j--;
            a[i] = a[j];
            while(a[i].x[id] <= pivot.x[id] && i < j)
                i++;
            a[j] = a[i];
        }
        a[i] = pivot;
        return i;
    }

    /**
     * Selects the k-th smallest (1-based) element of a[l..r] by coordinate id
     * and leaves it at index l + k - 1, with the range partitioned around it.
     *
     * @return the selected element, or null when the range is empty
     */
    public Node BFPTR(Node a[], int id, int l, int r, int k)
    {
        if(l > r) return null;
        Node num = FindMid(a, id, l, r);
        int p = FindId(a, l, r, num);
        int i = Partion(a, id, l, r, p);

        int m = i - l + 1;
        if(m == k) return a[i];
        if(m > k)  return BFPTR(a, id, l, i - 1, k);
        return BFPTR(a, id, i + 1, r, k - m);
    }

    /**
     * Computes the extent of p[l..r] along coordinate id.
     *
     * @return a Node whose x[0] is the minimum and x[1] the maximum
     */
    public Node getInterval(Node p[], int id, int l, int r){
        Node t = new Node();
        long max = Long.MIN_VALUE;
        long min = Long.MAX_VALUE;
        for(int i = l; i <= r; i++){
            if(max < p[i].x[id]) max = p[i].x[id];
            if(min > p[i].x[id]) min = p[i].x[id];
        }
        t.x[0] = min;
        t.x[1] = max;
        return t;
    }

    /** Squared Euclidean distance between a and b. */
    public long getDist(Node a, Node b){
        long dx = a.x[0] - b.x[0];
        long dy = a.x[1] - b.x[1];
        return dx * dx + dy * dy;
    }

    /** Arranges p[l..r] as a k-d subtree, splitting each node along the wider axis. */
    public void Build(Node p[], int l, int r){
        if(l > r) return;
        Node t1 = getInterval(p, 0, l, r);
        long minx = t1.x[0];
        long maxx = t1.x[1];

        Node t2 = getInterval(p, 1, l, r);
        long miny = t2.x[0];
        long maxy = t2.x[1];

        int mid = (l + r) >> 1;
        d[mid] = (maxx - minx > maxy - miny);

        // Put the median along the chosen axis at index mid.
        BFPTR(p, d[mid] ? 0 : 1, l, r, mid - l + 1);

        Build(p, l, mid - 1);
        Build(p, mid + 1, r);
    }

    /**
     * Searches the subtree p[l..r] for the point nearest to t, ignoring points
     * at distance zero. Updates res and index.
     */
    public void Find(Node p[], Node t, int l, int r){
        if(l > r) return;
        int mid = (l + r) >> 1;

        long dist = getDist(p[mid], t);
        // Signed distance from t to the splitting line of this node.
        long df = d[mid] ? (t.x[0] - p[mid].x[0]) : (t.x[1] - p[mid].x[1]);

        // dist > 0 skips the query point itself (and exact duplicates).
        if(dist > 0 && dist < res){
            res = dist;
            index = mid;
        }

        // Visit first the side of the split that contains t.
        int nearL = l;
        int nearR = mid - 1;
        int farL = mid + 1;
        int farR = r;
        if (df > 0){
            nearL = mid + 1;
            nearR = r;
            farL = l;
            farR = mid - 1;
        }
        Find(p, t, nearL, nearR);
        // The far side can only help if the splitting line is closer than the best match.
        if (df * df < res) Find(p, t, farL, farR);
    }

    /** Builds the tree over the first size points. */
    public void Build(){
        Build(p, 0, size - 1);
    }

    /**
     * Finds the point nearest to t at a non-zero distance.
     *
     * @return its index in the internal array; 0 when no such point exists
     */
    public int Search(Node t){
        clear();
        Find(p, t, 0, size - 1);
        return index;
    }

    /**
     * Reads the number of test cases; for each one reads n points and prints,
     * for every point in input order, the squared distance to its nearest
     * other point.
     */
    public static void main(String[] args){

        Scanner cin = new Scanner(System.in);
        int t = cin.nextInt();
        Main kd = new Main(SIZE);

        // node[] keeps the input order; the tree rearranges only its own
        // array of references to the same objects.
        Node[] node = new Node[SIZE];
        for(int i = 0; i < SIZE; i++){
            node[i] = kd.new Node();
        }

        while(t-- > 0){
            int n = cin.nextInt();
            kd.setSize(n);
            for(int i = 0; i < n; i++){
                node[i].x[0] = cin.nextLong();
                node[i].x[1] = cin.nextLong();
                kd.Insert(i, node[i]);
            }

            kd.Build();
            for(int i = 0; i < n; i++){
                int id = kd.Search(node[i]);
                System.out.println(kd.getDist(kd.get(id), node[i]));
            }
        }

    }
}

【KD-Tree带花费限制】
HDU-5992 Finding Hotels
https://vjudge.net/problem/HDU-5992

【题意】
给出N个旅店的二维坐标和价格,给出M个顾客的坐标和可以接受的价格,求每个顾客在可接受价格的条件内能找到的最近的旅店。
(N<=200000,M<=20000)

#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<string>
#include<vector>
#include<deque>
#include<queue>
#include<algorithm>
#include<set>
#include<map>
#include<stack>
#include<ctime>
#include <string.h>
#include<math.h>

using namespace std;
#define ll long long
#define pii pair<int,int>

const ll inf=1e17;          // value ansDis is reset to before each query
const int N = 200000 + 5;   // capacity of the hotel arrays
const int M = 20000 + 5;    // presumably the guest limit from the statement; not referenced below

const int demension=2;      // number of spatial dimensions (2-D)

// A hotel, or a guest's request when used as a query.
struct P{
    int pos[demension];     // coordinates
    int c;                  // price (hotel) or accepted price (guest)
    int id;                 // position of the hotel in the input order
}hotel[N];
P kdtree[N];
double var[demension];      // per-dimension variance, scratch space for build()
int split[N];               // split[i]: dimension on which the subtree rooted at index i is split
int cmpDem;                 // dimension that cmp() compares on

// Strict ordering of two points by their coordinate in dimension cmpDem.
bool cmp(const P&lhs,const P&rhs){
    const int d=cmpDem;
    return lhs.pos[d]<rhs.pos[d];
}

// Arranges hotel[l..r] in place as an implicit k-d tree: the root of a range
// is its middle element, and each node is split on the dimension with the
// largest variance over that range. Ranges with fewer than two elements are
// left untouched (split[] is not written for them).
void build(int l,int r){
    if(l>=r){
        return;
    }
    const int mid=(l+r)/2;
    const int count=r-l+1;

    // Compute the variance of every dimension and remember the largest one.
    // The strict comparison keeps the first dimension on ties.
    split[mid]=-1;
    double maxVar=-1;
    for(int dim=0;dim<demension;++dim){
        double total=0;
        for(int j=l;j<=r;++j){
            total+=hotel[j].pos[dim];
        }
        const double ave=total/count;   // mean

        double spread=0;
        for(int j=l;j<=r;++j){
            const double delta=hotel[j].pos[dim]-ave;
            spread+=delta*delta;
        }
        var[dim]=spread/count;          // variance

        if(var[dim]>maxVar){
            maxVar=var[dim];
            split[mid]=dim;
        }
    }

    // Place the median along the chosen dimension at index mid.
    cmpDem=split[mid];
    nth_element(hotel+l,hotel+mid,hotel+r+1,cmp);

    // Left and right subtrees.
    build(l,mid-1);
    build(mid+1,r);
}

int ansIndex;
ll ansDis;//ansDis=欧几里得距离^2
void query(int l,int r,P op){
    if(l>r){
        return;
    }
    int mid=(l+r)/2;
    //op到根节点距离
    ll dis=0;
    for(int i=0;i<demension;++i){
        dis+=(ll)(op.pos[i]-hotel[mid].pos[i])*(op.pos[i]-hotel[mid].pos[i]);
    }
    //更新ans
    if(hotel[mid].c<=op.c){
        if(dis==ansDis&&hotel[mid].id<hotel[ansIndex].id){
            ansIndex=mid;
        }
        if(dis<ansDis){
            ansDis=dis;
            ansIndex=mid;
        }
    }
    int d=split[mid];
    ll radius=(ll)(op.pos[d]-hotel[mid].pos[d])*(op.pos[d]-hotel[mid].pos[d]);//到分裂平面距离
    if(op.pos[d]<hotel[mid].pos[d]){
        query(l,mid-1,op);
        if(ansDis>=radius){
            query(mid+1,r,op);
        }
    }
    else{
        query(mid+1,r,op);
        if(ansDis>=radius){
            query(l,mid-1,op);
        }
    }
}

int main()
{
    //freopen("/home/lu/Documents/r.txt","r",stdin);
    int T;
    scanf("%d",&T);
    while(T--){
        int n,m;
        scanf("%d%d",&n,&m);
        for(int i=0;i<n;++i){
            scanf("%d%d%d",&hotel[i].pos[0],&hotel[i].pos[1],&hotel[i].c);
            hotel[i].id=i;
        }
        build(0,n-1);
        P p;
        for(int i=0;i<m;++i){
            scanf("%d%d%d",&p.pos[0],&p.pos[1],&p.c);
            ansDis=inf;
            ansIndex=-1;
            query(0,n-1,p);
            printf("%d %d %d\n",hotel[ansIndex].pos[0],hotel[ansIndex].pos[1],hotel[ansIndex].c);
        }

    }
    return 0;
}

猜你喜欢

转载自blog.csdn.net/Fire_to_cheat_/article/details/83421046