KNN与kd树及其实现

本文主要讲K近邻算法(KNN),kd树的构造和搜索

1.KNN算法

KNN是基本的分类算法,采用多数表决的方式预测。

wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

算法很简单,下面举个栗子,并运行看看结果

以电影为例子,给出一个数据集,再预测一个电影是爱情片还是动作片。下面是数据集即,打斗镜头和接吻镜头是数据的特征维度,电影类别是实例的类别,对应上面算法的y

wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

给出一个电影(18,90),打斗镜头18次,接吻镜头90次,现在预测它的类别,吗么根据算法先计算与每个样本的距离。

wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

假如k=3,取与x最近的三个,可以看到两个是爱情片,一个是动作片,所以根据多数表决的方式,预测这部电影是爱情片。

Python代码:

import numpy as np
import operator

#创建数据集
def createDataset():
    #四组二维特征
    group = np.array([[5,115],[7,106],[56,11],[66,9]])
    #四组对应标签
    labels = ('动作片','动作片','爱情片','爱情片')
    return group,labels

#分类算法
def classify(intX,dataSet,labels,k):

    #numpy中shape[0]返回数组的行数,shape[1]返回列数
    dataSetSize = dataSet.shape[0]
    #将intX在横向重复dataSetSize次,纵向重复1次
    #例如intX=([1,2])--->([[1,2],[1,2],[1,2],[1,2]])便于后面计算
    diffMat = np.tile(intX,(dataSetSize,1))-dataSet     #变成dataSetSize行,1列
    #二维特征相减后乘方
    sqdifMax = diffMat**2
    #计算距离
    #axis=1表示矩阵中每一行向量相加
    seqDistances = sqdifMax.sum(axis=1)
    distances = seqDistances**0.5
    print ("distances:",distances)
    #返回distance中元素从小到大排序后的索引
    sortDistance = distances.argsort()
    print ("sortDistance:",sortDistance)
    classCount = {}   #构建一个字典,用于存储类别和次,最后取最多次数的那个
    for i in range(k):
        #取出前k个元素的类别
        voteLabel = labels[sortDistance[i]]
        print ("第%d个voteLabel=%s",i,voteLabel)
        classCount[voteLabel] = classCount.get(voteLabel,0)+1
    #dict.get(key,default=None),字典的get()方法,返回指定键的值,如果值不在字典中返回默认值。
    #计算类别次数
    #sorted()和sort()区别在于前者不改变原样本,生成新的排序列表
    #key=operator.itemgetter(1)根据字典的值进行排序
    #key=operator.itemgetter(0)根据字典的键进行排序
    #reverse降序排序字典
    sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
    #结果sortedClassCount = [('动作片', 2), ('爱情片', 1)]
    print ("sortedClassCount:",sortedClassCount)
    return sortedClassCount[0][0]


if __name__ == '__main__':
    group,labels = createDataset()
    test = [20,101]   #预测数据
    test_class = classify(test,group,labels,3)
    print (test_class)
wAAACH5BAEKAAAALAAAAAABAAEAAAICRAEAOw==

2.kd树的构造

当样本的维度大且数量多的时候,为了考虑效率,采用特殊的存储训练数据,比如kd tree。

构造kd树和构造二叉树很类似。下面介绍具体算法。

  1. 构造根节点,根节点对应于包含数据集T的k维空间的超矩形区域。选择第一维维坐标轴,以T中所有样本数据的第一维的中位数为切分点,将根节点对应的超矩形区域切分为两个子区域。切分由通过切分点并与第一维坐标轴垂直的超平面实现。由根节点生成深度为1的左右子结点:左子结点对应第一维坐标小于切分点的子区域,右子结点对应第一位坐标大于切分点的子区域,将落在切分超平面上的实例点保存在根节点。
  2. 重复:对深度为j的结点,选择第i维切分的坐标轴,i=(jmodk)+1。以该节点的区域中所有实例的第i维坐标的中位数为切分点,将该节点对应的超矩形区域分为两个子区域。切分由通过切点并与第i维坐标轴垂直的超平面实现。由该结点生成深度为i+1的左右子结点。
  3. 直到连个子区域没有实例存在时停止,从而形成kd树的区域划分。

将《统计学习方法》的例3.2用Python实现,构造kd树,并用前序遍历验证一遍。


import numpy as np

#结点类
class Node():
    def __init__(self, data, lchild=None, rchild=None):
        self.data = data
        self.lchild = lchild
        self.rchild = rchild

#Kd树
class KdTree():
    def __init__(self):
        self.KdTree=None

    def create(self, dataSet, depth):
        if (len(dataSet)>0):                #一定要判断   否则抛出异常
            m, n = np.shape(dataSet)        #m个样本   n 维
            midIndex = int(m/2)             #中间值的索引
            axis = depth % n                #切分的维度  Python是从0开始,不用加1
            SortedDataSet = self.sort(dataSet,axis)   #排序
            node = Node(SortedDataSet[midIndex])
            leftDataSet = SortedDataSet[:midIndex]
            rightDataSet = SortedDataSet[midIndex+1:]
            node.lchild = self.create(leftDataSet,depth+1)   #递归构建左子树
            node.rchild = self.create(rightDataSet,depth+1)  #递归构建右子树
            return node

    def sort(self, dataset, axis):
        SortedDataSet = dataset[:]      #不能破坏原样本
        m,n = np.shape(SortedDataSet)
        for i in range(m):
            for j in range(0,m-i-1):
                if (SortedDataSet[j][axis] > SortedDataSet[j+1][axis]):
                    temp = SortedDataSet[j]
                    SortedDataSet[j] = SortedDataSet[j+1]
                    SortedDataSet[j + 1] = temp
        #  print(SortedDataSet)
        return SortedDataSet

#前序遍历检验是否构造正确
    def preOrder(self,node):
        if node!=None:
            print("--%s--" %node.data)
            self.preOrder(node.lchild)
            self.preOrder(node.rchild)

    
if __name__=='__main__':
    x = [2,4.5]   #预测点
    dataSet=[[2,3],
        [5,4],
        [9,6],
        [4,7],
        [8,1],
        [7,2]]
    kdtree = KdTree()
    tree=kdtree.create(dataSet,0)
    kdtree.preOrder(tree)     #检验先序遍历是否和书本一致

3.kd树搜索(按照最近邻)

算法:

输入:已构造的kd树;目标点x;

输出:x的最近邻

 

  1. 在kd树中找出包含目标点的x的叶结点:从根节点出发,递归地向下访问kd树。若目标点x当前维的坐标小于切分点的坐标,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止。
  2. 以此叶结点为“当前最近点”。
  3. 递归地向上回退,在每个结点进行以下操作:(a)如果该节点保存的实例点比当前最近点距离目标点更进,则以该实例点为“当前最近点”。(b)当前最近点一定存在于该节点的一个子结点对应的区域。检查该子结点的父节点的另一子结点对应的区域是否有更近的点。具体地,检查另一子结点对应的区域是否以目标点为球心,以目标点与“当前最近点的距离为半径的超球体相交”如果相交,可能在另一子结点对应的区域内存在距离更近的点,移动到另一个子节点。接着,递归地进行最近邻搜索。如果不相交,喜爱那个上回退。
  4. 退到根节点时,搜索结束,最后的“当前最近点”即为x的最近邻点。

搜索函数代码:

    def search(self, tree, x):
        self.nearestPoint = None
        self.nearestValue = 0

        def travel(node, depth=0):
            if node!=None:
                n = len(x)
                axis=depth % n
                if(node.data[axis]<=x[axis]):
                    travel(node.rchild,depth+1)
                else:
                    travel(node.lchild, depth+1)

                distNodeAndX = self.dist(node.data,x)
                if (self.nearestPoint==None):
                    self.nearestPoint=node.data
                    self.nearestValue=distNodeAndX
                elif(self.nearestValue>distNodeAndX):
                    self.nearestValue=distNodeAndX
                    self.nearestPoint=node.data

                if(abs(node.data[axis]-x[axis])<=self.nearestValue):
                    if(x[axis]<node.data[axis]):
                        travel(node.rchild,depth+1)
                    else:
                        travel(node.lchild,depth+1)
        travel(tree)
        return self.nearestPoint

    def dist(self, x1, x2):  # 欧式距离的计算
        return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

这里的递归很多,这里返回到父节点是递归结束后就返回到父节点了,然后判断父节点区域是否相交,相交的话再判断另一个子结点。代码和二叉树搜索相似,缺点是递归多了,时间复杂度比较大。这里我也花了很长时间去弄懂,建议手动写一遍这个查找过程,然后对比代码就明白了。

发布了16 篇原创文章 · 获赞 3 · 访问量 745

猜你喜欢

转载自blog.csdn.net/FeNGQiHuALOVE/article/details/94437711