机器学习-树回归

基于CART形成的回归树以及树的预剪枝和后剪枝,代码如下:

"""
机器学习-树回归(CART)
姓名:pcb
日期:2019.01.10
"""
from numpy import *


class treeNode():
    """Tree node for CART regression.

    Note: the dict-based tree built by createTree() below does not use this
    class; it is kept for interface compatibility with the book's code.
    """
    def __init__(self, feat, val, right, left):
        # Bug fix: the original assigned to bare local names, so every
        # attribute was silently discarded when __init__ returned.
        self.featureToSplitOn = feat   # index of the feature this node splits on
        self.valueOfSplit = val        # threshold value for the split
        self.rightBranch = right       # subtree/leaf for feature <= value
        self.leftBranch = left         # subtree/leaf for feature > value

# Load a tab-separated text file of numbers into a list of float rows.
def loadDataSet(filename):
    """
    :param filename: path to a tab-delimited text file where every field
                     parses as a float
    :return:         list of rows, each row a list of floats
    """
    dataMat = []
    # 'with' guarantees the file is closed (the original leaked the handle)
    with open(filename) as fr:
        for line in fr:
            curLine = line.strip().split('\t')
            # convert every field of the row to float in one pass
            dataMat.append(list(map(float, curLine)))
    return dataMat

# Partition a dataset in two by thresholding a single feature column.
def binSplitDataSet(dataSet, feature, value):
    """
    :param dataSet: data matrix
    :param feature: column index of the feature to split on
    :param value:   threshold for that feature
    :return:        (rows where feature > value, rows where feature <= value)
    """
    overIdx = nonzero(dataSet[:, feature] > value)[0]
    underIdx = nonzero(dataSet[:, feature] <= value)[0]
    return dataSet[overIdx, :], dataSet[underIdx, :]

# Leaf model: a regression leaf predicts the mean of the target column.
def regLeaf(dataSet):
    """Return the mean of the target (last) column of dataSet."""
    target = dataSet[:, -1]
    return mean(target)

# Error model: total squared deviation = variance * number of samples.
def regErr(dataSet):
    """Return the total squared error of the target (last) column."""
    numSamples = shape(dataSet)[0]
    return var(dataSet[:, -1]) * numSamples

# Find the best binary split of the dataset, or decide to create a leaf.
#
# Pseudocode:
#     for each feature:
#         for each distinct value of that feature:
#             split the data into two subsets
#             compute the error of the split
#             if the error beats the best so far, remember this split
#     return the best feature index and threshold
#
# Returns (None, leafValue) when splitting should stop: all targets are
# identical, the best split reduces the error by less than tolS, or a
# resulting subset would hold fewer than tolN rows (pre-pruning).
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """
    :param dataSet:  data matrix; the last column is the target
    :param leafType: function building a leaf value from a dataset
    :param errType:  function measuring a dataset's total error
    :param ops:      (tolS, tolN) = minimum error reduction, minimum rows
                     allowed in each split subset
    :return:         (feature index, threshold) or (None, leaf value)
    """
    tolS, tolN = ops

    # All target values equal: nothing to gain from splitting.
    if len(set(dataSet[:, -1].T.tolist()[0])) == 1:
        return None, leafType(dataSet)

    m, n = shape(dataSet)
    S = errType(dataSet)
    bestS, bestIndex, bestValue = inf, 0, 0
    for featIndex in range(n - 1):
        # Iterate over the distinct values of this feature column.
        for splitVal in set((dataSet[:, featIndex].T.A.tolist())[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            # Skip candidate splits leaving too few rows on either side.
            if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
                continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS:
                bestIndex, bestValue, bestS = featIndex, splitVal, newS

    # Error reduction too small: make a leaf instead.
    if (S - bestS) < tolS:
        return None, leafType(dataSet)

    # Best split would produce an undersized subset: make a leaf.
    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if shape(mat0)[0] < tolN or shape(mat1)[0] < tolN:
        return None, leafType(dataSet)

    return bestIndex, bestValue          # best feature and threshold



def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1, 4)):
    """Recursively grow a CART regression tree.

    :param dataSet:  data matrix; the last column is the target
    :param leafType: function building a leaf value from a dataset
    :param errType:  function measuring a dataset's total error
    :param ops:      (tolS, tolN) pre-pruning parameters
    :return:         nested dict with keys 'spInd', 'spVal', 'left',
                     'right', or a bare leaf value when no split helps
    """
    # Pick the best split (or a leaf value if splitting should stop).
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    if feat is None:          # idiom fix: compare to None with 'is', not '=='
        return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree


# ------------------- post-pruning of the regression tree -------------------
# A subtree is stored as a dict; a leaf is a plain number.
def isTree(obj):
    """Return True when obj is a subtree (dict) rather than a leaf value."""
    # isinstance is the idiomatic (and subclass-safe) type check,
    # unlike comparing type(obj).__name__ to a string.
    return isinstance(obj, dict)

# Collapse a subtree into a single value (recursive mean of its leaves).
# Descends until both children are leaves, then averages them upward.
def getMean(tree):
    """Return the collapsed mean value of all leaves under *tree*."""
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        # Bug fix: the original called getMean(tree['right']) here,
        # overwriting the left branch with the right branch's mean.
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0

# Pseudocode:
#     split the test data according to the existing tree:
#         if either child is a subtree, recursively prune that child
#         compute the error with the two leaves merged
#         compute the error without merging
#         if merging lowers the error, merge the leaves

def prune(tree, testData):
    """
    :param tree:     tree (dict) to post-prune
    :param testData: held-out test data matrix used to evaluate merges
    :return:         pruned tree, possibly collapsed to a single leaf value
    """
    # No test data reaches this node: collapse the subtree to its mean.
    if shape(testData)[0] == 0:
        return getMean(tree)

    # Route the test data through this node's split, then recursively
    # prune whichever children are still subtrees.
    if isTree(tree['left']) or isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']):
        tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']):
        tree['right'] = prune(tree['right'], rSet)

    # Both children are leaves now: decide whether merging them
    # lowers the squared error on the test data.
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:, -1] - tree['left'], 2)) + \
                       sum(power(rSet[:, -1] - tree['right'], 2))
        treeMean = (tree['left'] + tree['right']) / 2.0
        errorMerge = sum(power(testData[:, -1] - treeMean, 2))
        if errorMerge < errorNoMerge:
            print("merging")
            return treeMean
        return tree
    return tree
#----------------------------------------------------------------
# ---------------------------------------------------------------------------
def main():
    # 1. ----- basic tree construction (kept for reference) -----
    #     myDat = loadDataSet('ex0.txt')
    #     myMat = mat(myDat)
    #     myTree = createTree(myMat)
    #     print(myTree)

    # 2. ----- post-pruning test -----
    myDat2 = loadDataSet('ex2.txt')
    myMat2 = mat(myDat2)
    myTree = createTree(myMat2, ops=(100, 1))
    print(myTree)

    myDatTest = loadDataSet('ex2test.txt')
    myMat2Test = mat(myDatTest)
    pruneTree = prune(myTree, myMat2Test)
    print(pruneTree)


if __name__ == '__main__':
    main()


猜你喜欢

转载自blog.csdn.net/pcb931126/article/details/86478502