KD树的C++实现

理论介绍

kd树(K-dimension tree)是一种对k维空间中的实例点进行存储以便对其进行快速检索的树形数据结构。kd树是是一种二叉树,表示对k维空间的一个划分,构造kd树相当于不断地用垂直于坐标轴的超平面将K维空间切分,构成一系列的K维超矩形区域。kd树的每个结点对应于一个k维超矩形区域。利用kd树可以省去对大部分数据点的搜索,从而减少搜索的计算量。

  • kd树的算法步骤
    这里写图片描述
     
  • kd树可以较大的减少搜索空间,提高搜索效率,在knn,光线追踪这些需要大量搜索的算法中具有重要的应用前景。

算法实现

  • 使用C++ 对kd树进行了实现,没有使用模板,如有错误欢迎指出~
#pragma once
#include <iostream>
#include <vector>
#include <algorithm>
#include <fstream>
using namespace std;

int dimUsed;

// kd树的节点定义
struct KdNode
{
    KdNode * parent;
    KdNode * leftChild;
    KdNode * rightChild;
    vector<int> val;    //存储的数据
    int axis;       // 使用的轴
    KdNode(vector<int> data, int ax)
    {
        val = data;
        axis = ax;
        parent = NULL;
        leftChild = NULL;
        rightChild = NULL;
    }
};

// 用于以第n位进行比较,方便重载
bool cmp(vector<int> a, vector<int> b)
{
    if (a[dimUsed] < b[dimUsed])
        return true;
    return false;
}

ostream & operator<<(ostream & os, vector<int> vi)
{
    os << "(";
    for (int i = 0; i < vi.size(); i++)
        cout << vi[i] << ",";
    os << ")";
    return os;
}

// 导入数据,第一行是数据数量和维度,后面跟着num行dim维度的数据,每个一行
bool loadData(string filename, vector<vector<int> > &data)
{
    ifstream infs(filename);
    if (infs.is_open())
    {
        int num,dim;
        infs >> num>>dim;
        vector<int> d(dim);
        for (int i = 0; i < num; i++)
        {
            for (int j = 0; j < dim; j++)
                infs >> d[j];
            data.push_back(d);
        }
        return true;
    }
    return false;
}

// 计算N维向量距离
int disVector(vector<int> a, vector<int> b)
{
    int sum = 0;
    for (int i = 0; i < a.size(); i++)
        sum += (a[i] - b[i])*(a[i] - b[i]);
    return sum;
}

// kd 树的类
class KdTree
{
private:
    int dimension;
    vector<vector<int> > data;
    KdNode * root;
public:
    KdTree(vector<vector<int> > d, int dim)
    {
        dimension = dim;
        data = d;
    }
    void createTree()
    {
        // 递归建树
        root = createTreeNode(0, data.size()-1,0);
    }

    // create Kd Tree struct
    KdNode * createTreeNode(int left, int right,int dim)
    {
        if (right < left)
            return NULL;
        dimUsed = dim;
        // 按照k维进行排序
        sort(data.begin() + left, data.begin() + right+1, cmp);
        int mid = (left + right+1) / 2;
        KdNode * r = new KdNode(data[mid], dim);
        r->leftChild = createTreeNode(left, mid - 1, (dim + 1) % dimension);
        if (r->leftChild != NULL)
            r->leftChild->parent = r;
        r->rightChild = createTreeNode(mid + 1, right, (dim + 1) % dimension);
        if (r->rightChild != NULL)
            r->rightChild->parent = r;
        return r;
    }

    void printKdTree()
    {
        printKdTreeNode(root);
    }

    void printKdTreeNode(KdNode * r)
    {
        if (r == NULL)
            return;
        printKdTreeNode(r->leftChild);
        cout << r->val << "\t";
        printKdTreeNode(r->rightChild);
    }

    // 查找kd树
    KdNode * searchKdTree(vector<int> d)
    {
        int dim = 0,minDis = 10000000;
        KdNode * r = root;
        KdNode * tmp;
        while (r != NULL)
        {
            tmp = r;
            if (d[dim] < r->val[dim])                           
                r = r->leftChild;   
            else
                r = r->rightChild;
            dim = (dim + 1) % dimension;
        }
        // 找到属于的那个节点
        minDis = min(disVector(d, tmp->val),minDis);
        KdNode * nearNode = tmp;
        cout << endl<<"nearest node: "<<tmp->val << endl;
        // 退回到根节点
        while (tmp->parent != NULL)
        {
            tmp = tmp->parent;
            // 判断父节点是否更近,如果近,更新最近节点
            if (disVector(tmp->val, d) < minDis)
            {
                nearNode = tmp;
                minDis = disVector(tmp->val, d);
            }
            cout << "now parent node: " << tmp->val << endl;
            KdNode * son;
            // 判断当前轴与点的距离,如果小于minDis,则进行到另一半进行查找
            if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
            {
                if (tmp->val[tmp->axis] > d[tmp->axis])
                    son = tmp->rightChild;
                else
                    son = tmp->leftChild;
                searchKdTreeNode(d, minDis, nearNode, son);
            }
        }   
        // 对根节点的另外半边子树进行搜索
        /*if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
        {
            if (tmp->val[tmp->axis] > d[tmp->axis])
                tmp = tmp->rightChild;
            else
                tmp = tmp->leftChild;
            searchKdTreeNode(d, minDis, nearNode, tmp);
        }*/
        return nearNode;
    }

    // 查找当前节点下的最近点
    void searchKdTreeNode(vector<int> d,int &minDis,KdNode * &nearNode,KdNode * tmp)
    {
        // 递归终止
        if (tmp == NULL)
            return;
        cout << "now node: " << tmp->val << endl;
        // 判断当前节点是否小于
        if (disVector(tmp->val, d) < minDis)
        {
            minDis = disVector(tmp->val, d);
            nearNode = tmp;
        }
        // 如果轴与节点的距离小于minDis,则两个半边都需要搜索,否则只需要搜索一个半边
        if (abs(tmp->val[tmp->axis] - d[tmp->axis]) < minDis)
        {
            searchKdTreeNode(d, minDis, nearNode, tmp->leftChild);
            searchKdTreeNode(d, minDis, nearNode, tmp->rightChild);
        }
        else
        {
        // 选择搜索的一个半边
            if (tmp->val[tmp->axis] > d[tmp->axis])
                searchKdTreeNode(d, minDis, nearNode, tmp->leftChild);
            else
                searchKdTreeNode(d, minDis, nearNode, tmp->rightChild);
        }

    }

};

// 测试kd树
void testKdTree()
{
    vector<vector<int> > data;
    loadData("kd.txt", data);
    KdTree * kdtree = new KdTree(data, data[0].size());
    kdtree->createTree();
    kdtree->printKdTree();
    cout << endl;
    vector<int> vi(2);
    cin >> vi[0] >> vi[1];
    KdNode * r = kdtree->searchKdTree(vi);
    cout << r->val << endl;
}

/*
测试数据
6 2
2 3
5 4
9 6
4 7
8 1
7 2
*/

猜你喜欢

转载自blog.csdn.net/hu694028833/article/details/78166338