K-Nearest Neighbor (Statistical Learning Methods): A kd-tree Implementation in Python

Disclaimer: This is an original article by the blogger, licensed under the CC 4.0 BY-SA agreement. Please attach the original source link and this statement when reproducing.
Original link: https://blog.csdn.net/tudaodiaozhale/article/details/77327003

Foreword

The code can be downloaded from GitHub: download

k-nearest neighbor is probably one of the easiest machine learning algorithms to understand and implement; the first chapter of "Machine Learning in Action" uses it as an introduction. A straightforward implementation simply traverses the whole data set, measures the distance from the query to every sample, and keeps the k points with the smallest distances. However, once the number of samples is large, this brute-force approach causes a great deal of computation.
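For reference, the brute-force approach just described can be sketched in a few lines of NumPy (the sample data and query point here are only illustrative):

```python
import numpy as np

def knn_brute_force(dataSet, x, k):
    # Compute the Euclidean distance from x to every sample, then keep
    # the k samples with the smallest distances: O(m) distance
    # computations per query, where m is the number of samples.
    data = np.asarray(dataSet, dtype=float)
    dists = np.sqrt(((data - np.asarray(x, dtype=float)) ** 2).sum(axis=1))
    nearest = np.argsort(dists)[:k]   # indices of the k closest samples
    return data[nearest].tolist()

print(knn_brute_force([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]],
                      [3, 4.5], k=1))  # → [[2.0, 3.0]]
```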

So it is worth storing the data in a tree structure that supports fast retrieval; that structure is the kd-tree described here.

Implementation

The implementation is divided into two parts: building the kd-tree, and searching the kd-tree.

 

Building the kd-tree

# -*- coding: utf-8 -*-
import numpy as np

First we declare the file encoding and import the required package.

Next, let's implement a node class used to represent a node of the kd-tree.

class Node:
    def __init__(self, data, lchild = None, rchild = None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

A node contains a data field, a left child, and a right child. (If you are unfamiliar with binary trees, I suggest first reviewing the binary tree data structure and the preorder, inorder, and postorder traversal code.)

Binary tree reference code (C language)

Then comes the code that builds the kd-tree, mainly following Algorithm 3.2 on page 41 of the book.

    def create(self, dataSet, depth):   # build the kd-tree and return its root node
        if len(dataSet) > 0:
            m, n = np.shape(dataSet)    # number of samples (rows) and features (columns)
            midIndex = m // 2   # index of the median (integer division; m / 2 is a float in Python 3)
            axis = depth % n    # which axis to split the data on; formula j in Algorithm 3.2 (2)
            sortedDataSet = self.sort(dataSet, axis)    # sort the samples along that axis
            node = Node(sortedDataSet[midIndex])    # the node's data field is the median sample; see the book for details
            # print(sortedDataSet[midIndex])
            leftDataSet = sortedDataSet[: midIndex]     # samples to the left of the median
            rightDataSet = sortedDataSet[midIndex+1 :]  # samples to the right of the median
            print(leftDataSet)
            print(rightDataSet)
            node.lchild = self.create(leftDataSet, depth+1)     # recursively build the left subtree from the left samples
            node.rchild = self.create(rightDataSet, depth+1)
            return node
        else:
            return None

The code above is best understood alongside the book. The book chooses the splitting axis as l = j (mod k) + 1, where j is the depth and k the number of features; the median along that axis then decides which samples go to the left and right subtrees. Note that the code computes the axis as depth mod n (number of features) without the book's +1, because Python array indices start at zero.

    def sort(self, dataSet, axis):  # bubble sort, comparing samples by their coordinate on the given axis
        sortDataSet = dataSet[:]    # make a copy so the original samples are not disturbed
        m, n = np.shape(sortDataSet)
        for i in range(m):
            for j in range(0, m - i - 1):
                if sortDataSet[j][axis] > sortDataSet[j+1][axis]:
                    temp = sortDataSet[j]
                    sortDataSet[j] = sortDataSet[j+1]
                    sortDataSet[j+1] = temp
        print(sortDataSet)
        return sortDataSet

When building the tree, we need to sort the samples along the chosen axis (a single dimension) to find the median. Here I used a bubble sort.
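Bubble sort is O(m²) per call; if you prefer, NumPy can do the same per-axis sort in one line. A sketch (not part of the original code, and the helper name is my own):

```python
import numpy as np

def sort_by_axis(dataSet, axis):
    # Sort samples by their coordinate on `axis` without mutating the input:
    # argsort gives the row order, fancy indexing reorders the rows.
    data = np.asarray(dataSet)
    return data[np.argsort(data[:, axis])].tolist()

print(sort_by_axis([[7, 2], [5, 4], [2, 3]], axis=0))  # → [[2, 3], [5, 4], [7, 2]]
```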

    def preOrder(self, node):
        if node is not None:
            print("tttt->%s" % node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

Of course, I added a preorder traversal as a simple check that the tree was built correctly. (Just look at whether the traversal output looks right; this step can be skipped.)
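As a quick end-to-end check, here is a self-contained sketch of the build plus a preorder traversal, using the example data set from page 41 of the book (Python's built-in sorted stands in for the bubble sort; the free functions here are a simplification of the class methods above):

```python
import numpy as np

class Node:
    def __init__(self, data, lchild=None, rchild=None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

def create(dataSet, depth=0):
    # Standalone version of the tree-building routine above.
    if len(dataSet) == 0:
        return None
    m, n = np.shape(dataSet)
    axis = depth % n                            # splitting axis at this depth
    s = sorted(dataSet, key=lambda p: p[axis])  # sort along the splitting axis
    mid = m // 2                                # median index
    node = Node(s[mid])
    node.lchild = create(s[:mid], depth + 1)
    node.rchild = create(s[mid + 1:], depth + 1)
    return node

def pre_order(node, depth=0):
    if node is not None:
        print("  " * depth, node.data)
        pre_order(node.lchild, depth + 1)
        pre_order(node.rchild, depth + 1)

# The example data set from page 41 of the book:
tree = create([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
pre_order(tree)  # root is [7, 2], matching Figure 3.3 in the book
```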

 

Searching the kd-tree

    def search(self, tree, x):  # nearest-neighbor search
        self.nearestPoint = None    # the closest point found so far
        self.nearestValue = 0   # the distance to that point
        def travel(node, depth = 0):    # recursive search
            if node is not None:    # recursion stops past a leaf
                n = len(x)  # number of features
                axis = depth % n    # splitting axis at this depth
                if x[axis] < node.data[axis]:   # target is smaller on this axis, so descend left
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

                # after the recursion returns, we are backtracking; Algorithm 3.3 (3)
                distNodeAndX = self.dist(x, node.data)  # distance between the target and this node
                if self.nearestPoint is None:   # first candidate: update the closest point and distance, Algorithm 3.3 (3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif self.nearestValue > distNodeAndX:
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if abs(x[axis] - node.data[axis]) <= self.nearestValue:  # does the search circle cross the splitting plane? Algorithm 3.3 (3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)
        travel(tree)
        return self.nearestPoint

    def dist(self, x1, x2): # Euclidean distance
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

The tree search is the trickier part, so let's start with the principle.

(1) Find the leaf node of the kd-tree that contains the target point x: starting from the root, recursively descend the tree. If the target's coordinate on the current splitting dimension is smaller than the split point's, move to the left child; otherwise move to the right child, until reaching a leaf node.
(2) Take this leaf node as the "current closest point".
(3) Recursively back up, performing the following at each node:
  (A) If the example point stored at the node is closer to the target than the current closest point, make that example point the "current closest point".
  (B) The current closest point must lie in the region corresponding to one child of this node. Check whether the region of the node's other child contains a closer point. Concretely, check whether that region intersects the hypersphere centered at the target point whose radius is the distance between the target and the "current closest point". If they intersect, the other child's region may contain a closer point, so move to that child and continue the nearest-neighbor search recursively. If they do not intersect, keep backing up.
(4) When we back up to the root node, the search ends. The final "current closest point" is the nearest neighbor of x.

Attention: first descend step by step to a leaf node, then backtrack; while backtracking, do two things: (a) update the current closest point, and (b) check whether the region of the node's other child needs to be examined.

                if x[axis] < node.data[axis]:   # target is smaller on this axis, so descend left
                    travel(node.lchild, depth+1)
                else:
                    travel(node.rchild, depth+1)

This works like searching a binary search tree: descend until you reach a leaf node.

                # after the recursion returns, we backtrack toward the parent; Algorithm 3.3 (3)
                distNodeAndX = self.dist(x, node.data)  # distance between the target and this node
                if self.nearestPoint is None:   # first candidate: update the closest point and distance, Algorithm 3.3 (3)(a)
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX
                elif self.nearestValue > distNodeAndX:
                    self.nearestPoint = node.data
                    self.nearestValue = distNodeAndX

                print(node.data, depth, self.nearestValue, node.data[axis], x[axis])
                if abs(x[axis] - node.data[axis]) <= self.nearestValue:  # does the search circle cross the splitting plane? Algorithm 3.3 (3)(b)
                    if x[axis] < node.data[axis]:
                        travel(node.rchild, depth+1)
                    else:
                        travel(node.lchild, depth + 1)

This code implements step (3) of Algorithm 3.3 on page 43.

(a) is easy to implement. (b) checks whether the circle centered at the target point, with radius equal to the distance to the current closest point (as in Figure 3.5 on page 44 of the book, where the target point S and the current closest point D form a circle), intersects the splitting line through the parent node's point (the axis-aligned line that divides the region).

Put plainly, the formula is: |target's coordinate on the splitting axis - node's coordinate on that axis| <= nearest distance (the circle's radius). In Figure 3.5 on page 44 this amounts to reading off the target's coordinate on the splitting axis, subtracting the splitting line's coordinate, and checking whether the difference is smaller than the radius.
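That one-line test can be pulled out as a standalone function (the helper name and the example values are my own, chosen in the spirit of Figure 3.5, not taken from the book):

```python
def sphere_crosses_plane(target, node_point, axis, radius):
    # |target[axis] - node_point[axis]| is the distance from the target
    # to the splitting hyperplane through node_point along `axis`; the
    # other child's region can only contain a closer point when this
    # distance does not exceed the current search radius.
    return abs(target[axis] - node_point[axis]) <= radius

# Illustrative check: target (3, 4.5), node (5, 4) splitting on axis 0.
print(sphere_crosses_plane([3, 4.5], [5, 4], 0, 2.0))  # → True,  |3 - 5| = 2 <= 2.0
print(sphere_crosses_plane([3, 4.5], [5, 4], 0, 1.5))  # → False, |3 - 5| = 2 >  1.5
```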

Note: some readers asked in the comments which node node.data refers to here. To be clear, it is not the parent node but the current node in the backtracking. If you are not familiar with the binary tree data structure, this point is easy to confuse.

(It helps to first understand how a binary search tree is traversed.)

If the regions intersect, simply recurse into the other child once more. This corresponds to the following code:

travel(node.rchild, depth+1)

Finally, all of the code is posted on GitHub (if convenient, please give it a star; your support is my motivation to keep going). The code runs successfully under Python 3.5.

KNN (KDtree) download

The output is (5, 4).
