K-D树的算法原理与C++实现

K-D树

K-D树介绍

1.kd树本身是树结构的一种,但在节点上又不同于一般树。k-d树是每个节点都拥有k个维度数据的二叉树。二叉树是指所有非叶子节点可以视作用一个超平面把空间分区成两个半空间,节点左边的子树代表在超平面左边的点,节点右边的子树代表在超平面右边的点。

2.超平面:每个节点都与k维中垂直于超平面的那一维的维度有关。例如3维下取超平面为平面YOZ,则按照x轴划分,所有x值小于指定值的节点都会出现在左子树,所有x值大于指定值的节点都会出现在右子树。这样,超平面可以用该x值来确定,其法线为x轴的单位向量

K-D树的构造

 1.超平面的选择:采用最大方差法,计算出每个维度上的方差,取最大方差所在维度为垂直于超平面的方向轴(不是作为超平面而是作为方向轴)。因为方差越大,说明数据点在该维度上越为分散,从而分割的效果越好。

2.节点的选择:中值选择法,在数据结构的知识中已知平衡二叉树的平均查询效率最好,所以在选择各子树根节点时应该尽量保证分割后的树仍然能达到或接近平衡状态。将方向轴上的数据进行排序,取中值点为子树根节点进行超平面分割。

3.对左右子树重复以上过程,直至树建立完成。

相关表达说明

split:在KD树中用来表示方向轴序号。惯性思维会使用诸如x,y,z这种方向轴,而在多维空间中则是用自然数0,1,2,3……来表示不同维度。而这也方便了数组或者std::vector使用下标来引用不同维度上的数据。

最终呈现效果如下图

KNN算法

算法介绍

KNN算法称为最邻近点算法。是在KD树基础上的一种查找算法,即给定一个查找点,通过一定的方法遍历KD树从而查得一个与该点最邻近的节点。

算法原理

1.向下遍历:在遍历KD树时,根据每个节点的方向轴数据来确定遍历顺序。从根节点开始,将查找点与根节点方向轴上的数据进行比较,小于则进入左子树查找,反之则进入右子树查找。重复此过程直到遍历到叶子节点。完成向下遍历操作。

2.向上回溯:显然,遍历得到的点很大可能不是最近点,仅仅是处于同一子空间而已,既无法保证其距离小于查找点到叶子节点父节点的距离,也无法保证其距离小于查找点到另一子空间中节点的距离。

所以要进行回溯。首先,向上回溯到叶子节点的父节点,计算查找点到该父节点的距离,若小于之前距离则把该父节点当作最邻近点,并更新最短距离。反之则不更新。然后,以查找点为圆心,到查询到的邻近点距离为半径(此时可能是叶子节点也可能是叶子节点的父节点)进行画圆。观察是否与叶子节点的父节点另一子空间相交。若相交,则进入另一子空间继续查询(重复遍历和回溯操作)。若不相交,则继续回溯并重复操作,直到回溯至根节点方止。

举例示范

     

在该树中查找点(2,4.5)

向下遍历:

1.先从(7,2)查找,x=7为超平面方向轴,因为x=2 < x=7,所以遍历到左空间中的(5,4)节点。

2.在(5,4)点上查找时,y = 4为超平面的方向轴,因为由于查找点为y值为4.5,4.5>4,因此进入右子空间查找到(4,7)。

3.形成搜索路径<(7,2),(5,4),(4,7)>,取(4,7)为当前最近邻点,计算其与目标查找点的距离为3.202。

向上回溯:

1.然后回溯到父节点(5,4),计算其与查找点之间的距离为3.041。3.041<3.202,所以最邻近点更新为(5,4),最短距离更新为3.041。

2.以(2,4.5)为圆心,以3.041为半径作圆,如图所示。可见该圆和y = 4超平面相交,所以需要进入(5,4)左子空间进行查找。此时需将(2,3)节点加入搜索路径中得<(7,2),(2,3)>。

3.回溯至(2,3)叶子节点,(2,3)距离(2,4.5)比(5,4)要近,所以最近邻点更新为(2,3),最近距离更新为1.5。

4.回溯至根节点(7,2),以(2,4.5)为圆心1.5为半径作圆,并不和x = 7分割超平面相交。

5.至此,搜索路径回溯完。返回最近邻点(2,3),最近距离1.5。

C++代码实现

 kdtree.h

#ifndef KDTREE_H_
#define KDTREE_H_

#include<cmath>
#include<algorithm>
#include<stack>
#include<vector>
#include<iostream>
template<typename T> class KdTree { //定义节点结构 //---------- struct kdNode { std::vector<T> vec; int splitAttribute; kdNode *lChild; kdNode *rChild; kdNode *parent; kdNode(std::vector<T> v = {}, int split = 0, kdNode *lCh = nullptr, kdNode *rCh = nullptr, kdNode *par = nullptr) :vec(v), splitAttribute(split), lChild(lCh), rChild(rCh), parent(par) {} }; public: KdTree() { root = nullptr; } kdNode *getRoot() { return root; } std::vector<T> getRootData() { return root->vec; } //嵌套型数据结构 //------------ KdTree(std::vector<std::vector<T>> &data) { root = createKdTree(data); } //转置矩阵 //------- std::vector<std::vector<T>> transpose(std::vector<std::vector<T>> &data) { int m = data.size(); int n = data[0].size(); std::vector<std::vector<T>> trans(n, std::vector<T>(m, 0)); for (int i = 0; i < n; i++) { for (int j = 0; j < m; j++) { trans[i][j] = data[j][i]; } } return trans; } //计算每个方向上的方差 //----------------- double getVariance(std::vector<T> &vec) { int n = vec.size(); double sum = 0; for (int i = 0; i < n; i++) { sum += vec[i]; } double avg = sum / n; sum = 0; for (int i = 0; i < n; i++) { sum += pow(vec[i] - avg, 2);//#include<cmath> } return sum / n; } //根据最大方差确定垂直于超平面的轴序号split attribute //----------------------------------------- int getSplitAttribute(std::vector<std::vector<T>> &data) { int size = data.size(); int splitAttribute = 0; double maxVar = getVariance(data[0]); for (int i = 1; i < size; i++) { double temp = getVariance(data[i]); if (temp > maxVar) { splitAttribute = i; maxVar = temp; } } return splitAttribute; } //查询中值 //------- T getSplitValue(std::vector<T> &vec) { std::sort(vec.begin(), vec.end()); return vec[vec.size() / 2]; } //计算2个k维点的距离 //--------------- static double getDistance(std::vector<T> &v1, std::vector<T> &v2) { double sum = 0; for (size_t i = 0; i < v1.size(); i++) { sum += pow(v1[i] - v2[i], 2); } return sqrt(sum); } //创建kd-tree //----------- kdNode *createKdTree(std::vector<std::vector<T>> &data) { //cout << "create_1" << endl; if (data.empty()) return nullptr; int n = data.size(); if (n == 1) { return new kdNode(data[0], -1); } //获得轴序号与值 //------------ std::vector<std::vector<T>> data_T = transpose(data); int splitAttribute = getSplitAttribute(data_T); int splitValue = getSplitValue(data_T[splitAttribute]); //分割数据空间:根据attribute和value //------------------------------ std::vector<std::vector<T>> left; std::vector<std::vector<T>> right; int flag = 0; kdNode *splitNode = nullptr; for (int i = 0; i < n; i++) { if (flag == 0 && data[i][splitAttribute] == splitValue) { splitNode = new kdNode(data[i]); splitNode->splitAttribute = splitAttribute; flag = 1; continue; } if (data[i][splitAttribute] <= splitValue) { left.push_back(data[i]); } else { right.push_back(data[i]); } } splitNode->lChild = createKdTree(left); splitNode->rChild = createKdTree(right); return splitNode; } //-----------------------------最邻近算法------------------------------------ //------------------------------------------------------------------------- //指定起点查询 //---------- std::vector<T> searchNearestNeighbor(std::vector<T> &target, kdNode *start) { std::vector<T> NN = { 0,0 };//给定一个初始值 std::stack<kdNode *> searchPath; kdNode *p = start; if (p != nullptr) { while (p->splitAttribute != -1) //-1是指已到达边缘点,没有分割属性 { searchPath.push(p); int splitAttribute = p->splitAttribute; if (target[splitAttribute] <= p->vec[splitAttribute]) { p = p->lChild; } else { p = p->rChild; } } NN = p->vec; } double mindist = KdTree::getDistance(target, NN); kdNode *curNode; double dist; std::vector<T> nn; while (!searchPath.empty()) { curNode = searchPath.top(); searchPath.pop(); dist = KdTree::getDistance(target, curNode->vec); if (dist < mindist) { mindist = dist; NN = curNode->vec; //判断以target为中心,以dist为半径的球是否和节点的超平面相交 if (curNode->vec[curNode->splitAttribute] >= target[curNode->splitAttribute] - dist && curNode->vec[curNode->splitAttribute] <= target[curNode->splitAttribute] + dist) { if (target[curNode->splitAttribute] > curNode->vec[curNode->splitAttribute]) { nn = searchNearestNeighbor(target, curNode->lChild); } else { nn = searchNearestNeighbor(target, curNode->rChild); } if (KdTree::getDistance(target, nn) < KdTree::getDistance(target, NN)) { NN = nn; } } } else { if (curNode->vec[curNode->splitAttribute] >= target[curNode->splitAttribute] - mindist && curNode->vec[curNode->splitAttribute] <= target[curNode->splitAttribute] + mindist) { if (target[curNode->splitAttribute] > curNode->vec[curNode->splitAttribute]) { nn = searchNearestNeighbor(target, curNode->lChild); } else { nn = searchNearestNeighbor(target, curNode->rChild); } if (KdTree::getDistance(target, nn) < KdTree::getDistance(target, NN)) { NN = nn; } } } } return NN; } //从根节点进行查询 //------------- std::vector<T> searchNearestNeighbor(std::vector<T> &target) { std::vector<T> NN; NN = searchNearestNeighbor(target, root); return NN; } //打印kdTree //---------- void printTree(kdNode *root) { std::cout << "["; if (root->lChild) { std::cout << "left:"; printTree(root->lChild); } if (root) { std::cout << "("; for (size_t i = 0; i < root->vec.size(); i++) { std::cout << root->vec[i]; if (i != (root->vec.size() - 1)) std::cout << ","; } std::cout << ")"; } if (root->rChild) { std::cout << "right"; printTree(root->rChild); } std::cout << "]"; } private: kdNode * root; }; #endif // !KDTREE_H_

main.cpp

#include"KDTree.h"
using std::vector;
using std::cout;
using std::endl;
int main()
{
    double data[6][2] = { { 2,3 },{ 5,4 },{ 9,6 },{ 4,7 },{ 8,1 },{ 7,2 } };
    vector<vector<double>> train(6, vector<double>(2, 0));
    for (unsigned int i = 0; i < 6; i++)
    {
        for (unsigned int j = 0; j < 2; j++)
        {
            train[i][j] = data[i][j];
        }
    }
    KdTree<double> *Tree = new KdTree<double>(train);  

    //输出整棵树
    Tree->printTree(Tree->getRoot());
    cout << endl;
    cout << endl;

    //输出根节点
    vector<double> root = Tree->getRootData();
    vector<double>::iterator r = root.begin();
    cout << "root=";
    while (r != root.end())
        cout << *r++ << ",";

    //查找最近点
    cout << endl;
    cout << endl;
    vector<double> goal;
    double i, j;
    i = 9.0;
    j = 5.0;
    goal.push_back(i);
    goal.push_back(j);
    vector<double> nearestNeighbor = Tree->searchNearestNeighbor(goal);
    vector<double>::iterator beg = nearestNeighbor.begin();
    cout << endl;
    cout << "(" << i << "," << j << ") nearest neighbor is: ";
    while (beg != nearestNeighbor.end())
        cout << *beg++ << ",";
    cout << endl;
    return 0;
}

 代码初步测试已通过,还在不断优化中,若有错误或者更好的解决方案,非常感谢大家能够留言提出。

参考网站:

https://zh.wikipedia.org/wiki/K-d树

https://baike.baidu.com/item/kd-tree/2302515?fr=aladdin

https://www.cnblogs.com/earendil/p/8135074.html

https://www.cnblogs.com/wxquare/p/6497302.html

https://www.cnblogs.com/90zeng/p/kdtree.html

猜你喜欢

转载自www.cnblogs.com/jingrui/p/10469601.html