回归树的实现(3个核心函数)

def binSplitDataSet(dataSet,feature,value):
    bigIndex = dataSet[:,feature] > value
    smallIndex = dataSet[:,feature] <= value
    #print('bigIndex:',bigIndex)
    #print('smallIndex:',smallIndex)
    big = dataSet[nonzero(bigIndex)[0],:]
    small = dataSet[nonzero(smallIndex)[0],:]
    return small,big

第一个函数有三个参数:数据集合,待切分的特征和该特征的某个值,在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回

def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
   
    if len(set(dataSet[:,-1].T.A.tolist()[0]))==1:
        return None,leafType(dataSet)
    eS=ops[0]
    minSample=ops[1]
    bestError=errType(dataSet)
    originError=errType(dataSet)
    bestFeatIndex=0
    bestSplitValue=0
    m,n=shape(dataSet)
    for featIndex in range (n-1):
        for splitValue in set(dataSet[:,featIndex].T.A.tolist()[0]):
            mat0,mat1=binSplitDataSet(dataSet,featIndex,splitValue)
            if shape(mat0)[0]<minSample or shape(mat1)[0]<minSample:
                continue
            newS=errType(mat0)+errType(mat1)
            #判断newS是否是左右切分
            if newS < bestError:
                bestError=newS
                bestFeatIndex=featIndex
                bestSplitValue=splitValue
    #切分效果不佳
    if originError-bestError<eS:
        return None,leafType(dataSet)
    #可以切分
    mat0,mat1=binSplitDataSet(dataSet,bestFeatIndex,bestSplitValue)
    if shape(mat0)[0]<minSample or shape(mat1)[0]<minSample:
        return None,leafType(dataSet)
    #返回本次切分的最优特征和切分值
    return bestFeatIndex,bestSplitValue

chooseBestSplit()函数目标是找到数据集切分的最佳位置。
    它遍历所有的特征及其可能的取值来找到误差最小化的切分阈值。 
    
    对每个特征:
        对每个特征值: 
            将数据集切分成两份(小于该特征值的数据样本放在左子树,否则放在右子树)
            计算切分的误差
            如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
    返回最佳切分的特征和阈值
    
    Args:
        dataSet   加载的原始数据集
        leafType  建立叶子点的函数
        errType   误差计算函数(求总方差)
        ops       [容许误差下降值,切分的最少样本数]。非常重要,因为它决定了决策树划分停止的threshold值,被称为预剪枝(prepruning),
        其实也就是用于控制函数的停止时机。
        之所以这样说,是因为它防止决策树的过拟合,所以当误差的下降值小于tolS,或划分后的集合size小于tolN时,选择停止继续划分。

    Returns:
        bestIndex feature的index坐标
        bestValue 切分的最优值

def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
    feat,val=chooseBestSplit(dataSet,leafType,errType,ops)
    if feat==None:
        return val
    
    retTree={}
    retTree['spInd']=feat
    retTree['spVal']=val
    #大于在右边,小于在左边,分为两个数据集
    lSet,rSet=binSplitDataSet(dataSet,feat,val)
    #递归的进行调用,在左右子树中继续递归生成树
    retTree['left']=createTree(lSet,leafType,errType,ops)
    retTree['right']=createTree(rSet,leafType,errType,ops)
    return retTree

步骤:找到最佳的待切分特征
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法

Args:
        dataSet   加载的原始数据集
        leafType  建立叶子点的函数
        errType   误差计算函数(求总方差)
        ops       [容许误差下降值,切分的最少样本数]。非常重要,因为它决定了决策树划分停止的threshold值,被称为预剪枝

猜你喜欢

转载自blog.csdn.net/WJWFighting/article/details/81511250