【python】决策树CART


节选自《Machine Learning in Action》——Peter Harrington
中文版是《机器学习实战》
本文介绍的是CART算法,用python实现,编译器为jupyter

1 复杂数据的局部性建模

决策树是一种贪心算法,它要在给定时间内作出最佳选择,但是并不关心能否达到全局最优,

树回归

  1. 优点:可以对复杂和非线性的数据建模
  2. 缺点:结果不易理解

上篇ID3算法 的缺点

  1. 切分过于迅速:每次选取当前最佳的特征来分割数据,并按照该特征所有可能取值来切分。也就是说,如果一个特征有4个取值,那么数据将被切成4份,该特征在之后的算法执行过程中将不会再起作用。
  2. 不能处理连续型特征

树回归的一般方法

  1. 收集数据:anyway
  2. 准备数据:标称型数据应该映射成二值型数据
  3. 分析数据:绘出数据的委会可视化显示结果,以字典的方式生成树
  4. 训练算法:大部分时间都花费在叶节点树模型的构建上
  5. 测试算法:使用测试数据上的R平方值来分析模型的效果
  6. 使用算法:使用训练出的树做预测模型,预测结果还可以用来做很多事情

2 连续和离散型特征的树的构建

用字典的形式存储树的数据结构,该字典包含以下4个元素

  1. 待切分的特征
  2. 待切分的特征值
  3. 右子树。当不再需要切分的时候,也可以是单个值
  4. 左子树。与右子树类似
from numpy import *
import regTrees
#四阶,单位矩阵
testMat = mat(eye(4))
#把第一列特征(0开始编号),按照大于0.5或者小于等于0.5分类
mat0,mat1 = regTrees.binSplitDataSet(testMat,1,0.5)
print (mat0,'\n')
print (mat1)

结果为

[[ 0.  1.  0.  0.]] 

[[ 1.  0.  0.  0.]
 [ 0.  0.  1.  0.]
 [ 0.  0.  0.  1.]]

3 将CART算法用于回归

3.1 构建树

用createTree()来构建
伪代码大致如下:

找到最佳的待切分特征:
  如果该节点不能再分,将改节点存为叶节点
  执行二元切分
  在右子树调用createTree()方法
  在左子树调用createTree()方法

核心代码是chooseBestSplit()函数。给定某误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。
因此,chooseBestSplit()只需要完成两件事

1. 用最佳方式切分数据集
2. 生成相应的叶节点

伪代码如下
对每个特征:

  对每个特征值:
    将数据集切分成两份
    计算切分的误差
    如果当前误差小于最小误差,那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值

如果找不到一个好的二元切分,该函数返回None并同时调用createTree()方法来产生叶子节点

Note:此函数有三种情况不可切分

3.2 运行代码

from numpy import *
import regTrees
myDat = regTrees.loadDataSet('ex00.txt')  
myMat = mat(myDat)  
regTrees.createTree(myMat)

结果为

{'left': 1.0180967672413792,
 'right': -0.044650285714285719,
 'spInd': 0,
 'spVal': 0.48813}

数据集为

0.203693    -0.064036
0.355688    -0.119399
0.988852    1.069062
0.518735    1.037179
0.514563    1.156648
0.976414    0.862911
0.919074    1.123413
0.697777    0.827805
0.928097    0.883225
0.900272    0.996871
0.344102    -0.061539
0.148049    0.204298
0.130052    -0.026167
0.302001    0.317135
0.337100    0.026332
0.314924    -0.001952
0.269681    -0.165971
0.196005    -0.048847
0.129061    0.305107
0.936783    1.026258
0.305540    -0.115991
0.683921    1.414382
0.622398    0.766330
0.902532    0.861601
0.712503    0.933490
0.590062    0.705531
0.723120    1.307248
0.188218    0.113685
0.643601    0.782552
0.520207    1.209557
0.233115    -0.348147
0.465625    -0.152940
0.884512    1.117833
0.663200    0.701634
0.268857    0.073447
0.729234    0.931956
0.429664    -0.188659
0.737189    1.200781
0.378595    -0.296094
0.930173    1.035645
0.774301    0.836763
0.273940    -0.085713
0.824442    1.082153
0.626011    0.840544
0.679390    1.307217
0.578252    0.921885
0.785541    1.165296
0.597409    0.974770
0.014083    -0.132525
0.663870    1.187129
0.552381    1.369630
0.683886    0.999985
0.210334    -0.006899
0.604529    1.212685
0.250744    0.046297

4 剪枝

4.1 预剪枝

第3节简单的实验结果还是挺满意的,但树的构建对输入的参数tolS和tolN非常敏感,如果使用其他值将不太容易达到这么好的效果。
chooseBestSplit()终止条件,实际上是在进行一种所谓的预剪枝(prepruning操作)。也即不断修改ops的参数,但这并不是一个好办法,因为我们有时候不知道到底需要寻找什么样的结果。

另一种形式的剪枝需要使用测试集和训练集,称为后剪枝(postpruning),这是一种更理想化的剪枝方法。

4.2 后剪枝

函数prune()伪代码如下
基于已有的树切分测试数据:
  如果存在任一子集是一棵树,则在该子集递归剪枝过程
  计算将当前两个叶节点合并后的误差
  计算不合并的误差
  如果合并会降低误差的话,就将叶节点合并

from numpy import *
import regTrees
myDat2 = regTrees.loadDataSet('ex2.txt')  
myMat2 = mat(myDat2)  
regTrees.createTree(myMat2)

结果有很多节点


实验代码如下

myTree = regTrees.createTree(myMat2,ops=(0,1))
myDatTest = regTrees.loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTest)
regTrees.prune(myTree,myMat2Test)

结果发现,合并了许多节点

from numpy import *

def loadDataSet(fileName):      #general function to parse tab -delimited floats
    dataMat = []                #assume last column is target value
    fr = open(fileName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        #fltLine = map(float,curLine) #map all elements to float()
        fltLine = [float(item) for item in curLine]
        dataMat.append(fltLine)
    return dataMat

#形参:数据集集合、待切分的特征和该特征下的某个值
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1

def regLeaf(dataSet):#returns the value used for each leaf
    return mean(dataSet[:,-1])#计算最后一列

def regErr(dataSet):
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]# 1 和 4
    #tolS是容许的误差下降值,tolN是切分的最少样本数
    #if all the target variables are the same value: quit and return value

    #如果该数目是1,那么就不需要再切分而直接返回
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit condition 1
        return None, leafType(dataSet)

    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        #for splitVal in set(dataSet[:,featIndex]):#集合的形式
        for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):  
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS: 
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than a threshold don't do the split

    #如果切分数据集后效果提升不够大,那么就不应进行切分操作而直接创建叶节点   
    if (S - bestS) < tolS: 
        return None, leafType(dataSet) #exit conditon 2

    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit condition 3
    #检查切分后子集的大小,如果子集大小小于tolN,那么也不应切分
        return None, leafType(dataSet)

    #返回切分特征和特征值
    return bestIndex,bestValue#returns the best feature to split on
                              #and the value used for that split

#形参:数据集、建立叶节点的函数、误差计算函数、树构建所需要包含其他参数的元组
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
    if feat == None: return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree  

#回归剪枝函数
#测试输入变量是否是一棵树,返回布尔结果,换句话说,该函数用于判断当前处理的节点是否是叶节点
def isTree(obj):
    return (type(obj).__name__=='dict')

#递归调用,从上往下遍历,直到叶子节点为止,如果找到两个叶子节点,则计算他们的平均值。该函数对树进行塌陷处理
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
    if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)
    #if they are now both leafs, see if we can merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
            sum(power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge: 
            print ("merging")
            return treeMean
        else: return tree
    else: return tree

猜你喜欢

转载自blog.csdn.net/bryant_meng/article/details/79462380