HDU 4347 The Closest M Points

The Closest M Points
Time Limit: 8000MS Memory Limit: 98304KB 64bit IO Format: %I64d & %I64u

Submit Status

Description
The course of Software Design and Development Practice is objectionable. ZLC is facing a serious problem .There are many points in K-dimensional space .Given a point. ZLC need to find out the closest m points. Euclidean distance is used as the distance metric between two points. The Euclidean distance between points p and q is the length of the line segment connecting them.In Cartesian coordinates, if p = (p 1, p 2,…, p n) and q = (q 1, q 2,…, q n) are two points in Euclidean n-space, then the distance from p to q, or from q to p is given by:

Can you help him solve this problem?

Input
In the first line of the text file .there are two non-negative integers n and K. They denote respectively: the number of points, 1 <= n <= 50000, and the number of Dimensions,1 <= K <= 5. In each of the following n lines there is written k integers, representing the coordinates of a point. This followed by a line with one positive integer t, representing the number of queries,1 <= t <=10000.each query contains two lines. The k integers in the first line represent the given point. In the second line, there is one integer m, the number of closest points you should find,1 <= m <=10. The absolute value of all the coordinates will not be more than 10000.
There are multiple test cases. Process to end of file.

Output
For each query, output m+1 lines:
The first line saying :”the closest m points are:” where m is the number of the points.
The following m lines representing m points ,in accordance with the order from near to far
It is guaranteed that the answer can only be formed in one ways. The distances from the given point to all the nearest m+1 points are different. That means input like this:
2 2
1 1
3 3
1
2 2
1
will not exist.

Sample Input

3 2
1 1
1 3
3 4
2
2 3
2
2 3
1

Sample Output

the closest 2 points are:
1 3
3 4
the closest 1 points are:
1 3

这道题是一道可以用kd-tree来解决的题,关于kd-tree的介绍有很多,我扼要的介绍一下,kd-tree是一种二叉树,用来对很多数据进行划分,而这些数据都具有k种属性。例如:我们要分类一组气象数据,而每个数据都具有多种属性,如:气压,温度,湿度。在kd-tree中我们轮流使用这些属性来分类这些数据。还是用气象数据的例子,我们在kd-tree的第一层用气压分类,第二层用温度分类,第三层用湿度分类。

假设我们在某个节点使用第k个属性kt” role=”presentation” style=”position: relative;”>ktkt那我们进入该节点的左子树继续分类,否则进入右子树。我们举个例子, 这个例子参考了Christopher G. Healey’s Advanced data structure notes:

名字(name) 身高(ht) 体重(wt)
瞌睡虫(Sleepy) 36 48
开心果(Happy) 34 52
博士(Doc) 38 51
糊涂蛋(Dopey) 37 54
爱生气(Grumpy) 32 55
喷嚏精(Sneezy) 35 46
害羞鬼(Bashful) 33 50
白雪公主(Ms.White) 65 98

我们轮流使用两个属性身高和体重来分类数据,我们首先使用身高属性,通常我们可以使用 身高的中位数对应的值作为划分值,但这里我们为了简便,我们就按顺序选择每条数据去建立kd-tree,例如第一条数据是瞌睡虫,使用的是身高属性,因此kd-tree的根节点就是ht:36, 接着添加第二条数据:开心果。因为开心果的身高为34,所以进入到根节点的左子树,由于是第二层,所以我们应该使用体重属性,所以kd-tree,因此根节点的左子树的根节点为wt:52,同理我们添加第三个数据:博士,博士的身高高于36,因此进入到右子树,然后我们使用体重属性,因此根节点的右子树的根节点为wt:51. 如图所示:
这里写图片描述
我们接着添加数据,当所有的数据都添加后,我们就得到一棵完整的kd-tree,然后我们重新将所有数据通过这棵kd-tree,并将所有的的数据存储在kd-tree的叶子节点上,如图所示:
这里写图片描述
一般的kd-tree都是将数据存储在叶子节点上,但是对于我们这一题,我按照这种方式实现后,提交结果超时,后来我参考了别人的做法(http://blog.csdn.net/wxfwxf328/article/details/8158187),在他的实现中,直接将数据存在每个节点上,即做分类的节点(根节点及其他非叶子 节点),也存储了数据。即如图:
这里写图片描述这种做法在这道题上取得了较好的效果。知道了原理,接下来我们就来看下具体的实现方法。

#include<queue>   
#include<algorithm>  
#include<string.h>
using namespace std;

const int N = 50000;
short currentDim; //ys: which dimension we are currently working on
#define square(x) (x)*(x)
int n;
short K, t;
    
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

N表示的是可能输入的最多的点数,currentDim表示的是当前我们考虑的是哪一维,n表示具体的测试用例中的点数,K表示具体的测试用例的位数,t表示测试用例数。

struct point
{
    short pos[5];
    bool flag; //ys: whether this is a valid point
    point()
    {
        flag = false;
    }
    bool operator < (const point &a) const
    {
        return pos[currentDim] < a.pos[currentDim]; 
    }
};

point input_v[N];
priority_queue<pair<double, point>> pq;
point kd_tree[4*N];
    
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

因为题目告知最多为5维数据,所以我们定义pos为5维数组,flag表示当前这个点是不是一个有效点,只有一个点被赋予了具体的值之后,我们才能说它是一个有效点,所以最初都会被初始化为false。接下来是一个结构体内的运算符重载,这点很有用,因为定义了这个运算符重载之后,我们就可以直接比较两个point来了,而这个结果还会一来与当前使用的是哪一维有关,因为当前使用的是哪一维,我们就用哪一维进行比较。有了point结构,我们就可以定义存放输入数据的数组input_v, 我们使用了priority queue来帮我们完成对于pair的排序,而pair包括两部分,第一部分是一个double,第二部分是一个点,而第一部分的double表示的就是第二部分这个点到我们查询的这个点间的距离。然后我们定义了kd-tree, 我们注意到它包含的点数为4N,因为我们使用的是数组来实现kd-tree, 而我们的root的index为1,所以我们可知如果节点的index为x,那么它的左儿子节点的index为2x,它的右儿子节点的index为2x+1,即使当前节点是kd-tree中的最后一个节点,但我们并不知道,所以我们仍要检查它的左儿子和右儿子,直到发现它的两个儿子的flag都为false,我们才停止进一步检查。那么最坏的情况下我们要检测的节点的最大索引值是多少呢?如下图所示:
这里写图片描述
我们假设有四个点,我们每次选择当前维上排在中间的元素作为划分的节点,如果总的点数为偶数,那么就会有两个点排在中间,这是我们选择较小的那个作为划分的节点,因此如图所示,根节点把所有数据分为两类,比它小的有一个,比它大的有两个,图中每个节点中的数字表示该节点在数组中的存储位置,因为根节点的位置为1,那么它的左儿子的位置就应该为2,右儿子的位置就应该为3,由于右边有两个节点,所以我们要继续划分,找出当前维上排在中间的节点,因为只有两个节点,所以两个节点都算是排在中间的节点,因此我们选择较小的那个作为划分节点,存储在位置3上,而另外那个节点比这个划分节点大,所以成为它的右儿子,存储在位置7上,虽然这个节点已经是最后一个节点了,但是我们的程序判断不再进一步构建子树的条件是当前节点为无效节点(即flag=false), 但我们当前的节点是有效节点,所以程序仍会继续检测它的左子树和右子树,因此位置14和15会被检测到(当然这个过程可以优化,从而避免为位置14和15分配空间)。此时,我们发现4个节点,但在最坏的情况下,需要检测到位置15,由于我们的数组是从0开始的(虽然根节点是从1开始),所以我们的数组大小为16,是节点数的4倍,所以这就是为什么kd-tree数组的大小为4N。

inline void construct_kdTree(int n, short K, int p, int r, int index, short depth)
{
    if(p <= r)
    {
        currentDim = depth % K;
        int mid = (p + r) / 2;
        nth_element(input_v+p, input_v+mid, input_v+r+1);
        kd_tree[index] = input_v[mid];
        kd_tree[index].flag = true;
        construct_kdTree(n, K, p, mid-1, 2*index, depth+1);
        construct_kdTree(n, K, mid+1, r, 2*index+1, depth+1);
    }
}
    
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

下面就是创建kd-tree了,因为用递归的方式创建,所以先还是判断一下对应的区间的开始p和结尾r位置,如果开始位置在结尾位置之后,那么这个区间不成立,否则就应该创建对应的节点,首先找出当前使用的是哪一维,因为每一维是轮流使用的,所以我们通过当前节点所处的深度模上总的维数来得到当前使用的维数,不难发现同一深度的节点使用的是同一维。然后我们计算出当前区间的中间位置,如果有两个中间位置,则取较小的一个。接着我们使用stl中的nth_element,它的作用是把第n大的元素摆在第n个位置上,以我们程序中的nth_element ( input_v+p, input_v+mid, input_v+r+1);为例,我们区间input_v+p到input_v+r这段区间中的第mid-p+1大的元素放到input_v+mid这个位置上,同时保证在input_v+mid左侧的元素都小于等于它,input_v+mid右侧的元素都大于等于它,尽管如此,这些元素并不一定排好了,例如可能是2 1 3 5 4, 2,1 都小于3, 而5,4都大于3。把中间这个元素放到input_v+mid这个位置上了之后,我们就取这个中间元素作为分类节点的值,然后把flag设为true,表示这是一个有效的节点。接着我们继续递归地构建左子树和右子树。

inline void query_kdTree(point p, short closestM, short K, int index, short depth)
{
    //if(index >= 2*N || !kd_tree[index].flag)
    if(!kd_tree[index].flag)
        return;
    pair<double, point> current_node(0, kd_tree[index]);
    for(int i = 0; i < K; i++)
    {
        current_node.first += square(p.pos[i]-current_node.second.pos[i]);
    }

    short idm = depth % K;
    short flag = 0; //ys: whether we need to explore other side of the split
    int lchild = 2*index;
    int rchild = 2*index+1;

    if(kd_tree[index].pos[idm] <= p.pos[idm])
    {
        query_kdTree(p, closestM, K, rchild, depth+1);
    }
    else
    {
        query_kdTree(p, closestM, K, lchild, depth+1);
    }

    if(pq.size() < closestM)
    {
        pq.push(current_node);
        flag = 1;
    }
    else
    {
        if(current_node.first < pq.top().first)
        {
            pq.pop();
            pq.push(current_node);
        }
        if(square(p.pos[idm] - kd_tree[index].pos[idm]) < pq.top().first)
        {
            flag = 1;
        }
    }

    if(flag)
    {
        if(kd_tree[index].pos[idm] <= p.pos[idm])
            query_kdTree(p, closestM, K, lchild, depth+1);
        else
            query_kdTree(p, closestM, K, rchild, depth+1);
    }
}
    
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51

这部分就是kd-tree的查询了,首先一开始,我们判断是否当前节点是flag是否是true,如果是false则说明当前节点是无效节点,所以直接返回,否则进一步开始查询,接着使用到了stl中的pair,这里使用pair主要是想将point和该point与查询的point之间的距离绑定在一起。然后我们开始计算当前point与查询point之间的距离,这个地方本应该要开方才是距离,但由于我们用它来比较大小,所以不开方也不会影响大小关系,所以为了节省时间,此处就不进行开方运算。

接着我们计算当前使用的是哪一维,这个地方我们使用了一个局部变量来存储,而没有使用之前申明的全局变量,是因为这个变量将在递归调用前后多次使用,如果使用全局变量,在递归调用过程中这个值会被改变,因此,当程序从递归调用中返回时就与递归调用前的值不一致了,而这将会导致错误。接着的变量flag表示的是是否我们还需要搜索当前节点的另一分支。具体在什么情况下需要搜索另一分支,我们将在后面介绍。接着我们计算当前节点的左右儿子的位置。然后我们开始比较,如果查询point的值大于等于当前point的值,那么进入当前point的右子树继续搜索,否则进入当前point的左子树继续搜索。当子树搜索完毕后,所有结果都存储在名为pq的priority queue中,如果pq的大小不到closestM,例如我们要找出离查询point最近的5个点,而pq中现在只有3个点,那么我们肯定还要想pq中加点,所以首先我们将当前point加进去,然后将flag置为1,说明我们还需要搜寻当前节点的另一分支。另外方面如果pq的大小已经超过了closestM,那么我们就把当前point与查询point之间的距离和pq中的最大距离做比较,如果比pq中的最大距离小,那么我们就弹出pq中的最大距离点,把当前point加入到pq中(pq是大顶堆)。但是否存在查询point虽然被划分到左子树,而实际上它离右子树上的点更近呢,答案是肯定的,如图所示:
这里写图片描述
查询point A虽然和point B,point C分为一类,但实际上离他最近的是分在另外一类的point D。那么如何判断是否可能存在这种情况呢?其实很简单,我们检测查询point与当前point在当前维上的差距,即:我们图中的point A到划分直线的距离,如果这个距离大于pq中的最大距离,那么我们就不用检测划分直线另一侧的点了,因为划分直线另一侧的点到point A的距离必然大于等于划分直线到point A的距离。否则,我们就需要检测划分直线另一侧的点,即:当前point的另一分支。所以接下来的代码中,如果flag为true,那么我们就检测当前point的另一分支。

int main()
{
    while(scanf("%d%hd",&n,&K)!=EOF)  
    {
        for(int i = 0; i < n; i++)
        {
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &input_v[i].pos[j]);
            }
        }
        memset(kd_tree, 0, sizeof(point)*4*N);
        construct_kdTree(n, K, 0, n-1, 1, 0);

        scanf("%hd", &t);
        for(short c = 0; c < t; c++)
        {
            short closestM;
            point tmpv;
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &tmpv.pos[j]);
            }
            scanf("%hd", &closestM);
            query_kdTree(tmpv, closestM, K, 1, 0);

            printf("the closest %hd points are:\n", closestM);
            point pt[10];
            for(short i = 0; !pq.empty(); i++)
            {
                pt[i] = pq.top().second;
                pq.pop();
            }
            for(short i = closestM-1; i >= 0; i--)
            {
                for(short j = 0; j < K; j++)
                {
                    printf("%hd", pt[i].pos[j]);
                    if(j == K-1)
                        printf("\n");
                    else
                        printf(" ");
                }
            }
        }
    }
    return 0;
}
    
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

主函数就比较简单了,首先读入有多少个点,几维,然后把每个点读入到input_v数组中,因为有多组测试,每一组使用不同的kd-tree,所以现将kd-tree对应内存置0,然后构建kd-tree,然后读入要测试几个点,然后是具体的查询point,然后是要找出距离查询point最近的几个points。然后调用query_kdTree进行查询。得到结果后,把这些点按照距离查询point的距离由小到大打印出来,值得一提的是pq是大顶堆,所以我们从中弹出来的point距离查询point的距离是由大到小,所以我们使用了一个额外的数组pt,用于存放从pq中弹出的point,然后我们在按从后往前的顺序将结果打出。
下面是完整程序代码:

#include<queue>   
#include<algorithm>  
#include<string.h>
using namespace std;

const int N = 50000;
short currentDim; //ys: which dimension we are currently working on
#define square(x) (x)*(x)
int n;
short K, t;

struct point
{
    short pos[5];
    bool flag; //ys: whether this is a valid point
    point()
    {
        flag = false;
    }
    bool operator < (const point &a) const
    {
        return pos[currentDim] < a.pos[currentDim]; 
    }
};

point input_v[N];
priority_queue<pair<double, point>> pq;
point kd_tree[4*N];
//point kd_tree[2*N];

inline void construct_kdTree(int n, short K, int p, int r, int index, short depth)
{
    if(p <= r)
    {
        currentDim = depth % K;
        int mid = (p + r) / 2;
        nth_element(input_v+p, input_v+mid, input_v+r+1);
        kd_tree[index] = input_v[mid];
        kd_tree[index].flag = true;
        construct_kdTree(n, K, p, mid-1, 2*index, depth+1);
        construct_kdTree(n, K, mid+1, r, 2*index+1, depth+1);
    }
}

inline void query_kdTree(point p, short closestM, short K, int index, short depth)
{
    //if(index >= 2*N || !kd_tree[index].flag)
    if(!kd_tree[index].flag)
        return;
    pair<double, point> current_node(0, kd_tree[index]);
    for(int i = 0; i < K; i++)
    {
        current_node.first += square(p.pos[i]-current_node.second.pos[i]);
    }

    short idm = depth % K;
    short flag = 0; //ys: whether we need to explore other side of the split
    int lchild = 2*index;
    int rchild = 2*index+1;

    if(kd_tree[index].pos[idm] <= p.pos[idm])
    {
        query_kdTree(p, closestM, K, rchild, depth+1);
    }
    else
    {
        query_kdTree(p, closestM, K, lchild, depth+1);
    }

    if(pq.size() < closestM)
    {
        pq.push(current_node);
        flag = 1;
    }
    else
    {
        if(current_node.first < pq.top().first)
        {
            pq.pop();
            pq.push(current_node);
        }
        if(square(p.pos[idm] - kd_tree[index].pos[idm]) < pq.top().first)
        {
            flag = 1;
        }
    }

    if(flag)
    {
        if(kd_tree[index].pos[idm] <= p.pos[idm])
            query_kdTree(p, closestM, K, lchild, depth+1);
        else
            query_kdTree(p, closestM, K, rchild, depth+1);
    }
}

int main()
{
    while(scanf("%d%hd",&n,&K)!=EOF)  
    {
        for(int i = 0; i < n; i++)
        {
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &input_v[i].pos[j]);
            }
        }
        //memset(kd_tree, 0, sizeof(point)*2*N);
        memset(kd_tree, 0, sizeof(point)*4*N);
        construct_kdTree(n, K, 0, n-1, 1, 0);

        scanf("%hd", &t);
        for(short c = 0; c < t; c++)
        {
            short closestM;
            point tmpv;
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &tmpv.pos[j]);
            }
            scanf("%hd", &closestM);
            query_kdTree(tmpv, closestM, K, 1, 0);

            printf("the closest %hd points are:\n", closestM);
            point pt[10];
            for(short i = 0; !pq.empty(); i++)
            {
                pt[i] = pq.top().second;
                pq.pop();
            }
            for(short i = closestM-1; i >= 0; i--)
            {
                for(short j = 0; j < K; j++)
                {
                    printf("%hd", pt[i].pos[j]);
                    if(j == K-1)
                        printf("\n");
                    else
                        printf(" ");
                }
            }
        }
    }
    return 0;
}
    
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
                    <link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/markdown_views-ea0013b516.css">
                        </div>

The Closest M Points
Time Limit: 8000MS Memory Limit: 98304KB 64bit IO Format: %I64d & %I64u

Submit Status

Description
The course of Software Design and Development Practice is objectionable. ZLC is facing a serious problem .There are many points in K-dimensional space .Given a point. ZLC need to find out the closest m points. Euclidean distance is used as the distance metric between two points. The Euclidean distance between points p and q is the length of the line segment connecting them.In Cartesian coordinates, if p = (p 1, p 2,…, p n) and q = (q 1, q 2,…, q n) are two points in Euclidean n-space, then the distance from p to q, or from q to p is given by:

Can you help him solve this problem?

Input
In the first line of the text file .there are two non-negative integers n and K. They denote respectively: the number of points, 1 <= n <= 50000, and the number of Dimensions,1 <= K <= 5. In each of the following n lines there is written k integers, representing the coordinates of a point. This followed by a line with one positive integer t, representing the number of queries,1 <= t <=10000.each query contains two lines. The k integers in the first line represent the given point. In the second line, there is one integer m, the number of closest points you should find,1 <= m <=10. The absolute value of all the coordinates will not be more than 10000.
There are multiple test cases. Process to end of file.

Output
For each query, output m+1 lines:
The first line saying :”the closest m points are:” where m is the number of the points.
The following m lines representing m points ,in accordance with the order from near to far
It is guaranteed that the answer can only be formed in one ways. The distances from the given point to all the nearest m+1 points are different. That means input like this:
2 2
1 1
3 3
1
2 2
1
will not exist.

Sample Input

3 2
1 1
1 3
3 4
2
2 3
2
2 3
1

Sample Output

the closest 2 points are:
1 3
3 4
the closest 1 points are:
1 3

这道题是一道可以用kd-tree来解决的题,关于kd-tree的介绍有很多,我扼要的介绍一下,kd-tree是一种二叉树,用来对很多数据进行划分,而这些数据都具有k种属性。例如:我们要分类一组气象数据,而每个数据都具有多种属性,如:气压,温度,湿度。在kd-tree中我们轮流使用这些属性来分类这些数据。还是用气象数据的例子,我们在kd-tree的第一层用气压分类,第二层用温度分类,第三层用湿度分类。

假设我们在某个节点使用第k个属性kt” role=”presentation” style=”position: relative;”>ktkt那我们进入该节点的左子树继续分类,否则进入右子树。我们举个例子, 这个例子参考了Christopher G. Healey’s Advanced data structure notes:

名字(name) 身高(ht) 体重(wt)
瞌睡虫(Sleepy) 36 48
开心果(Happy) 34 52
博士(Doc) 38 51
糊涂蛋(Dopey) 37 54
爱生气(Grumpy) 32 55
喷嚏精(Sneezy) 35 46
害羞鬼(Bashful) 33 50
白雪公主(Ms.White) 65 98

我们轮流使用两个属性身高和体重来分类数据,我们首先使用身高属性,通常我们可以使用 身高的中位数对应的值作为划分值,但这里我们为了简便,我们就按顺序选择每条数据去建立kd-tree,例如第一条数据是瞌睡虫,使用的是身高属性,因此kd-tree的根节点就是ht:36, 接着添加第二条数据:开心果。因为开心果的身高为34,所以进入到根节点的左子树,由于是第二层,所以我们应该使用体重属性,所以kd-tree,因此根节点的左子树的根节点为wt:52,同理我们添加第三个数据:博士,博士的身高高于36,因此进入到右子树,然后我们使用体重属性,因此根节点的右子树的根节点为wt:51. 如图所示:
这里写图片描述
我们接着添加数据,当所有的数据都添加后,我们就得到一棵完整的kd-tree,然后我们重新将所有数据通过这棵kd-tree,并将所有的的数据存储在kd-tree的叶子节点上,如图所示:
这里写图片描述
一般的kd-tree都是将数据存储在叶子节点上,但是对于我们这一题,我按照这种方式实现后,提交结果超时,后来我参考了别人的做法(http://blog.csdn.net/wxfwxf328/article/details/8158187),在他的实现中,直接将数据存在每个节点上,即做分类的节点(根节点及其他非叶子 节点),也存储了数据。即如图:
这里写图片描述这种做法在这道题上取得了较好的效果。知道了原理,接下来我们就来看下具体的实现方法。

#include<queue>   
#include<algorithm>  
#include<string.h>
using namespace std;

const int N = 50000;
short currentDim; //ys: which dimension we are currently working on
#define square(x) (x)*(x)
int n;
short K, t;
  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

N表示的是可能输入的最多的点数,currentDim表示的是当前我们考虑的是哪一维,n表示具体的测试用例中的点数,K表示具体的测试用例的位数,t表示测试用例数。

struct point
{
    short pos[5];
    bool flag; //ys: whether this is a valid point
    point()
    {
        flag = false;
    }
    bool operator < (const point &a) const
    {
        return pos[currentDim] < a.pos[currentDim]; 
    }
};

point input_v[N];
priority_queue<pair<double, point>> pq;
point kd_tree[4*N];
  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

因为题目告知最多为5维数据,所以我们定义pos为5维数组,flag表示当前这个点是不是一个有效点,只有一个点被赋予了具体的值之后,我们才能说它是一个有效点,所以最初都会被初始化为false。接下来是一个结构体内的运算符重载,这点很有用,因为定义了这个运算符重载之后,我们就可以直接比较两个point来了,而这个结果还会一来与当前使用的是哪一维有关,因为当前使用的是哪一维,我们就用哪一维进行比较。有了point结构,我们就可以定义存放输入数据的数组input_v, 我们使用了priority queue来帮我们完成对于pair的排序,而pair包括两部分,第一部分是一个double,第二部分是一个点,而第一部分的double表示的就是第二部分这个点到我们查询的这个点间的距离。然后我们定义了kd-tree, 我们注意到它包含的点数为4N,因为我们使用的是数组来实现kd-tree, 而我们的root的index为1,所以我们可知如果节点的index为x,那么它的左儿子节点的index为2x,它的右儿子节点的index为2x+1,即使当前节点是kd-tree中的最后一个节点,但我们并不知道,所以我们仍要检查它的左儿子和右儿子,直到发现它的两个儿子的flag都为false,我们才停止进一步检查。那么最坏的情况下我们要检测的节点的最大索引值是多少呢?如下图所示:
这里写图片描述
我们假设有四个点,我们每次选择当前维上排在中间的元素作为划分的节点,如果总的点数为偶数,那么就会有两个点排在中间,这是我们选择较小的那个作为划分的节点,因此如图所示,根节点把所有数据分为两类,比它小的有一个,比它大的有两个,图中每个节点中的数字表示该节点在数组中的存储位置,因为根节点的位置为1,那么它的左儿子的位置就应该为2,右儿子的位置就应该为3,由于右边有两个节点,所以我们要继续划分,找出当前维上排在中间的节点,因为只有两个节点,所以两个节点都算是排在中间的节点,因此我们选择较小的那个作为划分节点,存储在位置3上,而另外那个节点比这个划分节点大,所以成为它的右儿子,存储在位置7上,虽然这个节点已经是最后一个节点了,但是我们的程序判断不再进一步构建子树的条件是当前节点为无效节点(即flag=false), 但我们当前的节点是有效节点,所以程序仍会继续检测它的左子树和右子树,因此位置14和15会被检测到(当然这个过程可以优化,从而避免为位置14和15分配空间)。此时,我们发现4个节点,但在最坏的情况下,需要检测到位置15,由于我们的数组是从0开始的(虽然根节点是从1开始),所以我们的数组大小为16,是节点数的4倍,所以这就是为什么kd-tree数组的大小为4N。

inline void construct_kdTree(int n, short K, int p, int r, int index, short depth)
{
    if(p <= r)
    {
        currentDim = depth % K;
        int mid = (p + r) / 2;
        nth_element(input_v+p, input_v+mid, input_v+r+1);
        kd_tree[index] = input_v[mid];
        kd_tree[index].flag = true;
        construct_kdTree(n, K, p, mid-1, 2*index, depth+1);
        construct_kdTree(n, K, mid+1, r, 2*index+1, depth+1);
    }
}
  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

下面就是创建kd-tree了,因为用递归的方式创建,所以先还是判断一下对应的区间的开始p和结尾r位置,如果开始位置在结尾位置之后,那么这个区间不成立,否则就应该创建对应的节点,首先找出当前使用的是哪一维,因为每一维是轮流使用的,所以我们通过当前节点所处的深度模上总的维数来得到当前使用的维数,不难发现同一深度的节点使用的是同一维。然后我们计算出当前区间的中间位置,如果有两个中间位置,则取较小的一个。接着我们使用stl中的nth_element,它的作用是把第n大的元素摆在第n个位置上,以我们程序中的nth_element ( input_v+p, input_v+mid, input_v+r+1);为例,我们区间input_v+p到input_v+r这段区间中的第mid-p+1大的元素放到input_v+mid这个位置上,同时保证在input_v+mid左侧的元素都小于等于它,input_v+mid右侧的元素都大于等于它,尽管如此,这些元素并不一定排好了,例如可能是2 1 3 5 4, 2,1 都小于3, 而5,4都大于3。把中间这个元素放到input_v+mid这个位置上了之后,我们就取这个中间元素作为分类节点的值,然后把flag设为true,表示这是一个有效的节点。接着我们继续递归地构建左子树和右子树。

inline void query_kdTree(point p, short closestM, short K, int index, short depth)
{
    //if(index >= 2*N || !kd_tree[index].flag)
    if(!kd_tree[index].flag)
        return;
    pair<double, point> current_node(0, kd_tree[index]);
    for(int i = 0; i < K; i++)
    {
        current_node.first += square(p.pos[i]-current_node.second.pos[i]);
    }

    short idm = depth % K;
    short flag = 0; //ys: whether we need to explore other side of the split
    int lchild = 2*index;
    int rchild = 2*index+1;

    if(kd_tree[index].pos[idm] <= p.pos[idm])
    {
        query_kdTree(p, closestM, K, rchild, depth+1);
    }
    else
    {
        query_kdTree(p, closestM, K, lchild, depth+1);
    }

    if(pq.size() < closestM)
    {
        pq.push(current_node);
        flag = 1;
    }
    else
    {
        if(current_node.first < pq.top().first)
        {
            pq.pop();
            pq.push(current_node);
        }
        if(square(p.pos[idm] - kd_tree[index].pos[idm]) < pq.top().first)
        {
            flag = 1;
        }
    }

    if(flag)
    {
        if(kd_tree[index].pos[idm] <= p.pos[idm])
            query_kdTree(p, closestM, K, lchild, depth+1);
        else
            query_kdTree(p, closestM, K, rchild, depth+1);
    }
}
  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51

这部分就是kd-tree的查询了,首先一开始,我们判断是否当前节点是flag是否是true,如果是false则说明当前节点是无效节点,所以直接返回,否则进一步开始查询,接着使用到了stl中的pair,这里使用pair主要是想将point和该point与查询的point之间的距离绑定在一起。然后我们开始计算当前point与查询point之间的距离,这个地方本应该要开方才是距离,但由于我们用它来比较大小,所以不开方也不会影响大小关系,所以为了节省时间,此处就不进行开方运算。

接着我们计算当前使用的是哪一维,这个地方我们使用了一个局部变量来存储,而没有使用之前申明的全局变量,是因为这个变量将在递归调用前后多次使用,如果使用全局变量,在递归调用过程中这个值会被改变,因此,当程序从递归调用中返回时就与递归调用前的值不一致了,而这将会导致错误。接着的变量flag表示的是是否我们还需要搜索当前节点的另一分支。具体在什么情况下需要搜索另一分支,我们将在后面介绍。接着我们计算当前节点的左右儿子的位置。然后我们开始比较,如果查询point的值大于等于当前point的值,那么进入当前point的右子树继续搜索,否则进入当前point的左子树继续搜索。当子树搜索完毕后,所有结果都存储在名为pq的priority queue中,如果pq的大小不到closestM,例如我们要找出离查询point最近的5个点,而pq中现在只有3个点,那么我们肯定还要想pq中加点,所以首先我们将当前point加进去,然后将flag置为1,说明我们还需要搜寻当前节点的另一分支。另外方面如果pq的大小已经超过了closestM,那么我们就把当前point与查询point之间的距离和pq中的最大距离做比较,如果比pq中的最大距离小,那么我们就弹出pq中的最大距离点,把当前point加入到pq中(pq是大顶堆)。但是否存在查询point虽然被划分到左子树,而实际上它离右子树上的点更近呢,答案是肯定的,如图所示:
这里写图片描述
查询point A虽然和point B,point C分为一类,但实际上离他最近的是分在另外一类的point D。那么如何判断是否可能存在这种情况呢?其实很简单,我们检测查询point与当前point在当前维上的差距,即:我们图中的point A到划分直线的距离,如果这个距离大于pq中的最大距离,那么我们就不用检测划分直线另一侧的点了,因为划分直线另一侧的点到point A的距离必然大于等于划分直线到point A的距离。否则,我们就需要检测划分直线另一侧的点,即:当前point的另一分支。所以接下来的代码中,如果flag为true,那么我们就检测当前point的另一分支。

int main()
{
    while(scanf("%d%hd",&n,&K)!=EOF)  
    {
        for(int i = 0; i < n; i++)
        {
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &input_v[i].pos[j]);
            }
        }
        memset(kd_tree, 0, sizeof(point)*4*N);
        construct_kdTree(n, K, 0, n-1, 1, 0);

        scanf("%hd", &t);
        for(short c = 0; c < t; c++)
        {
            short closestM;
            point tmpv;
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &tmpv.pos[j]);
            }
            scanf("%hd", &closestM);
            query_kdTree(tmpv, closestM, K, 1, 0);

            printf("the closest %hd points are:\n", closestM);
            point pt[10];
            for(short i = 0; !pq.empty(); i++)
            {
                pt[i] = pq.top().second;
                pq.pop();
            }
            for(short i = closestM-1; i >= 0; i--)
            {
                for(short j = 0; j < K; j++)
                {
                    printf("%hd", pt[i].pos[j]);
                    if(j == K-1)
                        printf("\n");
                    else
                        printf(" ");
                }
            }
        }
    }
    return 0;
}
  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48

主函数就比较简单了,首先读入有多少个点,几维,然后把每个点读入到input_v数组中,因为有多组测试,每一组使用不同的kd-tree,所以现将kd-tree对应内存置0,然后构建kd-tree,然后读入要测试几个点,然后是具体的查询point,然后是要找出距离查询point最近的几个points。然后调用query_kdTree进行查询。得到结果后,把这些点按照距离查询point的距离由小到大打印出来,值得一提的是pq是大顶堆,所以我们从中弹出来的point距离查询point的距离是由大到小,所以我们使用了一个额外的数组pt,用于存放从pq中弹出的point,然后我们在按从后往前的顺序将结果打出。
下面是完整程序代码:

#include<queue>   
#include<algorithm>  
#include<string.h>
using namespace std;

const int N = 50000;
short currentDim; //ys: which dimension we are currently working on
#define square(x) (x)*(x)
int n;
short K, t;

struct point
{
    short pos[5];
    bool flag; //ys: whether this is a valid point
    point()
    {
        flag = false;
    }
    bool operator < (const point &a) const
    {
        return pos[currentDim] < a.pos[currentDim]; 
    }
};

point input_v[N];
priority_queue<pair<double, point>> pq;
point kd_tree[4*N];
//point kd_tree[2*N];

inline void construct_kdTree(int n, short K, int p, int r, int index, short depth)
{
    if(p <= r)
    {
        currentDim = depth % K;
        int mid = (p + r) / 2;
        nth_element(input_v+p, input_v+mid, input_v+r+1);
        kd_tree[index] = input_v[mid];
        kd_tree[index].flag = true;
        construct_kdTree(n, K, p, mid-1, 2*index, depth+1);
        construct_kdTree(n, K, mid+1, r, 2*index+1, depth+1);
    }
}

inline void query_kdTree(point p, short closestM, short K, int index, short depth)
{
    //if(index >= 2*N || !kd_tree[index].flag)
    if(!kd_tree[index].flag)
        return;
    pair<double, point> current_node(0, kd_tree[index]);
    for(int i = 0; i < K; i++)
    {
        current_node.first += square(p.pos[i]-current_node.second.pos[i]);
    }

    short idm = depth % K;
    short flag = 0; //ys: whether we need to explore other side of the split
    int lchild = 2*index;
    int rchild = 2*index+1;

    if(kd_tree[index].pos[idm] <= p.pos[idm])
    {
        query_kdTree(p, closestM, K, rchild, depth+1);
    }
    else
    {
        query_kdTree(p, closestM, K, lchild, depth+1);
    }

    if(pq.size() < closestM)
    {
        pq.push(current_node);
        flag = 1;
    }
    else
    {
        if(current_node.first < pq.top().first)
        {
            pq.pop();
            pq.push(current_node);
        }
        if(square(p.pos[idm] - kd_tree[index].pos[idm]) < pq.top().first)
        {
            flag = 1;
        }
    }

    if(flag)
    {
        if(kd_tree[index].pos[idm] <= p.pos[idm])
            query_kdTree(p, closestM, K, lchild, depth+1);
        else
            query_kdTree(p, closestM, K, rchild, depth+1);
    }
}

int main()
{
    while(scanf("%d%hd",&n,&K)!=EOF)  
    {
        for(int i = 0; i < n; i++)
        {
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &input_v[i].pos[j]);
            }
        }
        //memset(kd_tree, 0, sizeof(point)*2*N);
        memset(kd_tree, 0, sizeof(point)*4*N);
        construct_kdTree(n, K, 0, n-1, 1, 0);

        scanf("%hd", &t);
        for(short c = 0; c < t; c++)
        {
            short closestM;
            point tmpv;
            for(short j = 0; j < K; j++)
            {
                scanf("%hd", &tmpv.pos[j]);
            }
            scanf("%hd", &closestM);
            query_kdTree(tmpv, closestM, K, 1, 0);

            printf("the closest %hd points are:\n", closestM);
            point pt[10];
            for(short i = 0; !pq.empty(); i++)
            {
                pt[i] = pq.top().second;
                pq.pop();
            }
            for(short i = closestM-1; i >= 0; i--)
            {
                for(short j = 0; j < K; j++)
                {
                    printf("%hd", pt[i].pos[j]);
                    if(j == K-1)
                        printf("\n");
                    else
                        printf(" ");
                }
            }
        }
    }
    return 0;
}
  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
                    <link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/markdown_views-ea0013b516.css">
                        </div>

猜你喜欢

转载自blog.csdn.net/jxy0123456789/article/details/80096489