机器学习—K近邻,KD树算法python实现

完整代码及数据集下载
K近邻算法不需要训练,通过搜索训练集,找到预测点与训练集中距离最近的k个实例,然后取k个实例中最多的类型作为预测点的类型。K近邻算法三个基本要素:K值的选择,距离度量,分类决策规则。
算法:
a.在训练集中找出距离预测点x最近的k个点记为集合N。
b.在N中采用多数表决的决策规则,选出最多的类别y作为x的类。

距离度量

闵可夫斯基(Minkowski)距离

Lp(xi,xj)=(l=1n|xlixlj|p)1p

当p=2时,为欧式距离
Lp(xi,xj)=(l=1n|xlixlj|2)12

当p=1时,为曼哈顿距离
Lp(xi,xj)=l=1n|xlixlj|

k值的选择

较小的k值,“学习”的近似误差会减小,估计误差会增大。特征空间将被划分为更多的子空间,模型会变得负责,会出现过拟合。
较大的k值,“学习”的近似误差会增大,估计误差会减小。与输入实例较远的训练实例也会对预测起作用,是预测发生错误。k值的增大意味着模型的简单,可能出现欠拟合。
通常采用交叉验证的方法确定合适的k值。
关于近似误差与估计误差可以看近似误差与估计误差理解

分类决策规则

可以将分类的损失函数看做0-1损失函数,则经验风险为误分类率,若想使误分类率减小,则选取k近邻中出现次数最多的类型。


KD树

在进行k近邻算法时,每一次预测需要对整个训练集进行遍历计算距离,如果训练数据集非常大,则每次计算非常耗时。下面通过kd树结构来减少计算距离的次数。
kd树是一种对k维空间的实例点进行存储以便快速检索的树形结构。kd树是二叉树,表示对k维空间的划分。构造树的过程相当于不断的用垂直于坐标轴的超平面将k维空间切分,构成一系列的k维超巨型域,kd树的每个节点对应于一个k维超巨型区域。

构造方法:构造跟节点,使根节点对应于k维空间包含所有实例点的超巨型区域;通过如下递归的方法,不断对k维空间进行切分,生成子节点。在超平面区域中选择一个坐标轴和在此坐标轴上的一个切分点,通常选取方差最大的轴,方差大意味着数据更分散蕴含信息更多,切分后收益更大,切分点选择实例点的中位数点。然后超平面通过切分点垂直于选定的坐标轴将超巨型区域分为左右两个子区域,此时样本实例也对应分到两个子区域中,递归上步直到子区域没有实例点时终止。
选中位数进行切分的树为平衡树,但是平衡的kd树搜索时效率未必最优。
这里写图片描述
构造算法

def createTree(数据集)
    选择坐标轴
    选择中位数作为切分点
    记录当前点的实例点
    将数据集分为两部分
    createTree(左数据集)
    createTree(右数据集)
def createKDTree(dataSet,depth):

    n = np.shape(dataSet)[0]
    treeNode = {}
    if n == 0:
        return None
    else:
        n,m = np.shape(dataSet)
        split_axis = depth % m
        depth += 1
        treeNode['split'] = split_axis
        dataSet = sorted(dataSet,key= lambda a: a[split_axis])
        num = n // 2 
        treeNode['median'] = dataSet[num]
        treeNode['left'] = createKDTree(dataSet[:num],depth)
        treeNode['right'] = createKDTree(dataSet[num+1:],depth)
        return treeNode

构建树的时间复杂度为 o(knlogn) ,k为实例的维数,n为实例的个数。
这里写图片描述

搜索算法:
(1) 在kd树中找出包含目标点xx的叶结点:从根结点出发,递归的向下访问kd树。若目标点当前维的坐标值小于切分点的坐标值,则移动到左子结点,否则移动到右子结点。直到子结点为叶结点为止;
(2) 以此叶结点为“当前最近点”;
(3) 递归的向上回退,在每个结点进行以下操作:
 (a) 如果该结点保存的实例点比当前最近点距目标点更近,则以该实例点为“当前最近点”;
 (b) 当前最近点一定存在于该结点一个子结点对应的区域。检查该子结点的父结点的另一个子结点对应的区域是否有更近的点。具体的,检查另一个子结点对应的区域是否与以目标点为球心、以目标点与“当前最近点”间的距离为半径的超球体相交。如果相交,可能在另一个子结点对应的区域内存在距离目标更近的点,移动到另一个子结点。接着,递归的进行最近邻搜索。如果不相交,向上回退。
(4) 当回退到根结点时,搜索结束。最后的“当前最近点”即为xx的最近邻点。
搜索算法:

def search()
    判断是否为叶子
        是的话 返回距离与最近点
    判断向左右哪边递归
    search()
    计算当前节点与预测点的距离
    更新距离与最近点
    是否与当前节点的另一个子树相交(通过当前切分轴与数据点的距离是否大于当前最小距离判断)
        如果不相交 返回距离与最近点
    search()另一边
    更新距离与最近点
    返回 最近距离与最近点
def searchTree(tree,data):
    k = len(data)
    if tree is None:
        return [0]*k,float('inf')
    split_axis = tree['split']
    median_point = tree['median']
    if data[split_axis] <= median_point[split_axis]:
        nearestPoint,nearestDistance = searchTree(tree['left'],data)
    else:
        nearestPoint,nearestDistance = searchTree(tree['right'],data)
    nowDistance = np.linalg.norm(data-median_point)   #the distance between data to current point
    if nowDistance < nearestDistance:
        nearestDistance = nowDistance
        nearestPoint = median_point.copy()
    splitDistance = abs(data[split_axis] - median_point[split_axis])# the distance between hyperplane
    if splitDistance > nearestDistance:
        return nearestPoint,nearestDistance
    else:
        if data[split_axis] <= median_point[split_axis]:
            nextTree = tree['right']
        else:
            nextTree = tree['left']
        nearPoint,nearDistanc = searchTree(nextTree,data)
        if nearDistanc < nearestDistance:
            nearestDistance = nearDistanc
            nearestPoint = nearPoint.copy()
        return nearestPoint,nearestDistance

完整代码及数据集下载

猜你喜欢

转载自blog.csdn.net/weixin_37895339/article/details/78809528