8.4 kd树

转载博客:https://www.cnblogs.com/LCcnblogs/p/6169136.html
kd树和knn算法的c语言实现
  基于kd树的knn的实现原理可以参考文末的链接,都是一些好文章。

这里参考了别人的代码。用c语言写的包括kd树的构建与查找k近邻的程序。

#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<time.h>

typedef struct{
    
    //数据维度
    double x;
    double y;
}data_struct;

typedef struct kd_node{
    
    
    data_struct split_data;//数据结点
    int split;//分裂维
    struct kd_node *left;//由位于该结点分割超面左子空间内所有数据点构成的kd-tree
    struct kd_node *right;//由位于该结点分割超面右子空间内所有数据点构成的kd-tree
}kd_struct;

//用于排序
int cmp1( const void *a , const void *b )
{
    
    
    return (*(data_struct *)a).x > (*(data_struct *)b).x ? 1:-1;
}
//用于排序
int cmp2( const void *a , const void *b )
{
    
    
    return (*(data_struct *)a).y > (*(data_struct *)b).y ? 1:-1;
}
//计算分裂维和分裂结点
void choose_split(data_struct data_set[],int size,int dimension,int *split,data_struct *split_data)
{
    
    
    int i;
    data_struct *data_temp;
    data_temp=(data_struct *)malloc(size*sizeof(data_struct));
    for(i=0;i<size;i++)
        data_temp[i]=data_set[i];
    static int count=0;//设为静态
    *split=(count++)%dimension;//分裂维
    if((*split)==0) qsort(data_temp,size,sizeof(data_temp[0]),cmp1);
    else qsort(data_temp,size,sizeof(data_temp[0]),cmp2);
    *split_data=data_temp[(size-1)/2];//分裂结点排在中位
}
//判断两个数据点是否相等
int equal(data_struct a,data_struct b){
    
    
    if(a.x==b.x && a.y==b.y)    return 1;
    else    return 0;
}
//建立KD树
kd_struct *build_kdtree(data_struct data_set[],int size,int dimension,kd_struct *T)
{
    
    
    if(size==0) return NULL;//递归出口
    else{
    
    
        int sizeleft=0,sizeright=0;
        int i,split;
        data_struct split_data;
        choose_split(data_set,size,dimension,&split,&split_data);
        data_struct data_right[size];
        data_struct data_left[size];

        if (split==0){
    
    //x维
            for(i=0;i<size;++i){
    
    
                if(!equal(data_set[i],split_data) && data_set[i].x <= split_data.x){
    
    //比分裂结点小
                    data_left[sizeleft].x=data_set[i].x;
                    data_left[sizeleft].y=data_set[i].y;
                    sizeleft++;//位于分裂结点的左子空间的结点数
                }
                else if(!equal(data_set[i],split_data) && data_set[i].x > split_data.x){
    
    //比分裂结点大
                    data_right[sizeright].x=data_set[i].x;
                    data_right[sizeright].y=data_set[i].y;
                    sizeright++;//位于分裂结点的右子空间的结点数
                }
            }
        }
        else{
    
    //y维
            for(i=0;i<size;++i){
    
    
                if(!equal(data_set[i],split_data) && data_set[i].y <= split_data.y){
    
    
                    data_left[sizeleft].x=data_set[i].x;
                    data_left[sizeleft].y=data_set[i].y;
                    sizeleft++;
                }
                else if (!equal(data_set[i],split_data) && data_set[i].y > split_data.y){
    
    
                    data_right[sizeright].x = data_set[i].x;
                    data_right[sizeright].y = data_set[i].y;
                    sizeright++;
                }
            }
        }
        T=(kd_struct *)malloc(sizeof(kd_struct));
        T->split_data.x=split_data.x;
        T->split_data.y=split_data.y;
        T->split=split;
        T->left=build_kdtree(data_left,sizeleft,dimension,T->left);//左子空间
        T->right=build_kdtree(data_right,sizeright,dimension,T->right);//右子空间
        return T;//返回指针
    }
}
//计算欧氏距离
double compute_distance(data_struct a,data_struct b){
    
    
    double tmp=pow(a.x-b.x,2.0)+pow(a.y-b.y,2.0);
    return sqrt(tmp);
}
//搜索1近邻
void search_nearest(kd_struct *T,int size,data_struct test,data_struct *nearest_point,double *distance)
{
    
    
    int path_size;//搜索路径内的指针数目
    kd_struct *search_path[size];//搜索路径保存各结点的指针
    kd_struct* psearch=T;
    data_struct nearest;//最近邻的结点
    double dist;//查询结点与最近邻结点的距离
    search_path[0]=psearch;//初始化搜索路径
    path_size=1;
    while(psearch->left!=NULL || psearch->right!=NULL){
    
    
        if (psearch->split==0){
    
    
            if(test.x <= psearch->split_data.x)//如果小于就进入左子树
                psearch=psearch->left;
            else
                psearch=psearch->right;
        }
        else{
    
    
            if(test.y <= psearch->split_data.y)//如果小于就进入右子树
                psearch=psearch->left;
            else
                psearch=psearch->right;
        }
        search_path[path_size++]=psearch;//将经过的分裂结点保存在搜索路径中
    }
    //取出search_path最后一个元素,即叶子结点赋给nearest
    nearest.x=search_path[path_size-1]->split_data.x;
    nearest.y=search_path[path_size-1]->split_data.y;
    path_size--;//search_path的指针数减一
    dist=compute_distance(nearest,test);//计算与该叶子结点的距离作为初始距离

    //回溯搜索路径
    kd_struct* pback;
    while(path_size!=0){
    
    
        pback=search_path[path_size-1];//取出search_path最后一个结点赋给pback
        path_size--;//search_path的指针数减一

        if(pback->left==NULL && pback->right==NULL){
    
    //如果pback为叶子结点
            if(dist>compute_distance(pback->split_data,test)){
    
    
                nearest=pback->split_data;
                dist=compute_distance(pback->split_data,test);
            }
        }
        else{
    
    //如果pback为分裂结点
            int s=pback->split;
            if(s==0){
    
    //x维
                if(fabs(pback->split_data.x-test.x)<dist){
    
    //若以查询点为中心的圆(球或超球),半径为dist的圆与分割超平面相交,那么就要跳到另一边的子空间去搜索
                    if(dist>compute_distance(pback->split_data,test)){
    
    
                        nearest=pback->split_data;
                        dist=compute_distance(pback->split_data, test);
                    }
                    if(test.x<=pback->split_data.x)//若查询点位于pback的左子空间,那么就要跳到右子空间去搜索
                        psearch=pback->right;
                    else
                        psearch=pback->left;//若以查询点位于pback的右子空间,那么就要跳到左子空间去搜索
                    if(psearch!=NULL)
                        search_path[path_size++]=psearch;//psearch加入到search_path中
                }
            }
            else {
    
    //y维
                if(fabs(pback->split_data.y-test.y)<dist){
    
    //若以查询点为中心的圆(球或超球),半径为dist的圆与分割超平面相交,那么就要跳到另一边的子空间去搜索
                    if(dist>compute_distance(pback->split_data,test)){
    
    
                        nearest=pback->split_data;
                        dist=compute_distance(pback->split_data,test);
                    }
                    if(test.y<=pback->split_data.y)//若查询点位于pback的左子空间,那么就要跳到右子空间去搜索
                        psearch=pback->right;
                    else
                        psearch=pback->left;//若查询点位于pback的的右子空间,那么就要跳到左子空间去搜索
                    if(psearch!=NULL)
                        search_path[path_size++]=psearch;//psearch加入到search_path中
                }
            }
        }
    }

    (*nearest_point).x=nearest.x;//最近邻
    (*nearest_point).y=nearest.y;
    *distance=dist;//距离
}

int main()
{
    
    
    int n=6;//数据个数
    data_struct nearest_point;
    double distance;
    kd_struct *root=NULL;
    data_struct data_set[6]={
    
    {
    
    2,3},{
    
    5,4},{
    
    9,6},{
    
    4,7},{
    
    8,1},{
    
    7,2}};//数据集
    data_struct test={
    
    7.1,2.1};//查询点
    root=build_kdtree(data_set,n,2,root);

    search_nearest(root,n,test,&nearest_point,&distance);
    printf("nearest neighbor:(%.2f,%.2f)\ndistance:%.2f \n",nearest_point.x,nearest_point.y,distance);
    return 0;
}
/*                    x          5,4
                                /    \
                      y       2,3    7.2
                                \    /  \
                      x        4,7  8.1 9.6
*/
 

参考:

https://www.joinquant.com/post/2627?f=study&m=math

https://www.joinquant.com/post/2843?f=study&m=math

http://blog.csdn.net/zhl30041839/article/details/9277807

猜你喜欢

转载自blog.csdn.net/ZXG20000/article/details/114766313
8.4