节选自《Machine Learning in Action》——Peter Harrington
中文版是《机器学习实战》
本文介绍的是CART算法,用python实现,编译器为jupyter
1 复杂数据的局部性建模
决策树是一种贪心算法,它要在给定时间内作出最佳选择,但是并不关心能否达到全局最优,
树回归
- 优点:可以对复杂和非线性的数据建模
- 缺点:结果不易理解
上篇ID3算法 的缺点
- 切分过于迅速:每次选取当前最佳的特征来分割数据,并按照该特征所有可能取值来切分。也就是说,如果一个特征有4个取值,那么数据将被切成4份,该特征在之后的算法执行过程中将不会再起作用。
- 不能处理连续型特征
树回归的一般方法
- 收集数据:anyway
- 准备数据:标称型数据应该映射成二值型数据
- 分析数据:绘出数据的委会可视化显示结果,以字典的方式生成树
- 训练算法:大部分时间都花费在叶节点树模型的构建上
- 测试算法:使用测试数据上的R平方值来分析模型的效果
- 使用算法:使用训练出的树做预测模型,预测结果还可以用来做很多事情
2 连续和离散型特征的树的构建
用字典的形式存储树的数据结构,该字典包含以下4个元素
- 待切分的特征
- 待切分的特征值
- 右子树。当不再需要切分的时候,也可以是单个值
- 左子树。与右子树类似
from numpy import *
import regTrees
#四阶,单位矩阵
testMat = mat(eye(4))
#把第一列特征(0开始编号),按照大于0.5或者小于等于0.5分类
mat0,mat1 = regTrees.binSplitDataSet(testMat,1,0.5)
print (mat0,'\n')
print (mat1)
结果为
[[ 0. 1. 0. 0.]]
[[ 1. 0. 0. 0.]
[ 0. 0. 1. 0.]
[ 0. 0. 0. 1.]]
3 将CART算法用于回归
3.1 构建树
用createTree()来构建
伪代码大致如下:
找到最佳的待切分特征:
如果该节点不能再分,将改节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法
核心代码是chooseBestSplit()函数。给定某误差计算方法,该函数会找到数据集上最佳的二元切分方式。另外该函数还要确定什么时候停止切分,一旦停止切分会生成一个叶节点。
因此,chooseBestSplit()只需要完成两件事
1. 用最佳方式切分数据集
2. 生成相应的叶节点
伪代码如下
对每个特征:
对每个特征值:
将数据集切分成两份
计算切分的误差
如果当前误差小于最小误差,那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
如果找不到一个好的二元切分,该函数返回None并同时调用createTree()方法来产生叶子节点
Note:此函数有三种情况不可切分
3.2 运行代码
from numpy import *
import regTrees
myDat = regTrees.loadDataSet('ex00.txt')
myMat = mat(myDat)
regTrees.createTree(myMat)
结果为
{'left': 1.0180967672413792,
'right': -0.044650285714285719,
'spInd': 0,
'spVal': 0.48813}
数据集为
0.203693 -0.064036
0.355688 -0.119399
0.988852 1.069062
0.518735 1.037179
0.514563 1.156648
0.976414 0.862911
0.919074 1.123413
0.697777 0.827805
0.928097 0.883225
0.900272 0.996871
0.344102 -0.061539
0.148049 0.204298
0.130052 -0.026167
0.302001 0.317135
0.337100 0.026332
0.314924 -0.001952
0.269681 -0.165971
0.196005 -0.048847
0.129061 0.305107
0.936783 1.026258
0.305540 -0.115991
0.683921 1.414382
0.622398 0.766330
0.902532 0.861601
0.712503 0.933490
0.590062 0.705531
0.723120 1.307248
0.188218 0.113685
0.643601 0.782552
0.520207 1.209557
0.233115 -0.348147
0.465625 -0.152940
0.884512 1.117833
0.663200 0.701634
0.268857 0.073447
0.729234 0.931956
0.429664 -0.188659
0.737189 1.200781
0.378595 -0.296094
0.930173 1.035645
0.774301 0.836763
0.273940 -0.085713
0.824442 1.082153
0.626011 0.840544
0.679390 1.307217
0.578252 0.921885
0.785541 1.165296
0.597409 0.974770
0.014083 -0.132525
0.663870 1.187129
0.552381 1.369630
0.683886 0.999985
0.210334 -0.006899
0.604529 1.212685
0.250744 0.046297
4 剪枝
4.1 预剪枝
第3节简单的实验结果还是挺满意的,但树的构建对输入的参数tolS和tolN非常敏感,如果使用其他值将不太容易达到这么好的效果。
chooseBestSplit()终止条件,实际上是在进行一种所谓的预剪枝(prepruning操作)。也即不断修改ops的参数,但这并不是一个好办法,因为我们有时候不知道到底需要寻找什么样的结果。
另一种形式的剪枝需要使用测试集和训练集,称为后剪枝(postpruning),这是一种更理想化的剪枝方法。
4.2 后剪枝
函数prune()伪代码如下
基于已有的树切分测试数据:
如果存在任一子集是一棵树,则在该子集递归剪枝过程
计算将当前两个叶节点合并后的误差
计算不合并的误差
如果合并会降低误差的话,就将叶节点合并
from numpy import *
import regTrees
myDat2 = regTrees.loadDataSet('ex2.txt')
myMat2 = mat(myDat2)
regTrees.createTree(myMat2)
结果有很多节点
实验代码如下
myTree = regTrees.createTree(myMat2,ops=(0,1))
myDatTest = regTrees.loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTest)
regTrees.prune(myTree,myMat2Test)
结果发现,合并了许多节点
from numpy import *
def loadDataSet(fileName): #general function to parse tab -delimited floats
dataMat = [] #assume last column is target value
fr = open(fileName)
for line in fr.readlines():
curLine = line.strip().split('\t')
#fltLine = map(float,curLine) #map all elements to float()
fltLine = [float(item) for item in curLine]
dataMat.append(fltLine)
return dataMat
#形参:数据集集合、待切分的特征和该特征下的某个值
def binSplitDataSet(dataSet, feature, value):
mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
return mat0,mat1
def regLeaf(dataSet):#returns the value used for each leaf
return mean(dataSet[:,-1])#计算最后一列
def regErr(dataSet):
return var(dataSet[:,-1]) * shape(dataSet)[0]
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
tolS = ops[0]; tolN = ops[1]# 1 和 4
#tolS是容许的误差下降值,tolN是切分的最少样本数
#if all the target variables are the same value: quit and return value
#如果该数目是1,那么就不需要再切分而直接返回
if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit condition 1
return None, leafType(dataSet)
m,n = shape(dataSet)
#the choice of the best feature is driven by Reduction in RSS error from mean
S = errType(dataSet)
bestS = inf; bestIndex = 0; bestValue = 0
for featIndex in range(n-1):
#for splitVal in set(dataSet[:,featIndex]):#集合的形式
for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):
mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
newS = errType(mat0) + errType(mat1)
if newS < bestS:
bestIndex = featIndex
bestValue = splitVal
bestS = newS
#if the decrease (S-bestS) is less than a threshold don't do the split
#如果切分数据集后效果提升不够大,那么就不应进行切分操作而直接创建叶节点
if (S - bestS) < tolS:
return None, leafType(dataSet) #exit conditon 2
mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): #exit condition 3
#检查切分后子集的大小,如果子集大小小于tolN,那么也不应切分
return None, leafType(dataSet)
#返回切分特征和特征值
return bestIndex,bestValue#returns the best feature to split on
#and the value used for that split
#形参:数据集、建立叶节点的函数、误差计算函数、树构建所需要包含其他参数的元组
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is NumPy Mat so we can array filtering
feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
if feat == None: return val #if the splitting hit a stop condition return val
retTree = {}
retTree['spInd'] = feat
retTree['spVal'] = val
lSet, rSet = binSplitDataSet(dataSet, feat, val)
retTree['left'] = createTree(lSet, leafType, errType, ops)
retTree['right'] = createTree(rSet, leafType, errType, ops)
return retTree
#回归剪枝函数
#测试输入变量是否是一棵树,返回布尔结果,换句话说,该函数用于判断当前处理的节点是否是叶节点
def isTree(obj):
return (type(obj).__name__=='dict')
#递归调用,从上往下遍历,直到叶子节点为止,如果找到两个叶子节点,则计算他们的平均值。该函数对树进行塌陷处理
def getMean(tree):
if isTree(tree['right']): tree['right'] = getMean(tree['right'])
if isTree(tree['left']): tree['left'] = getMean(tree['left'])
return (tree['left']+tree['right'])/2.0
def prune(tree, testData):
if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
if (isTree(tree['right']) or isTree(tree['left'])):#if the branches are not trees try to prune them
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
if isTree(tree['right']): tree['right'] = prune(tree['right'], rSet)
#if they are now both leafs, see if we can merge them
if not isTree(tree['left']) and not isTree(tree['right']):
lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
sum(power(rSet[:,-1] - tree['right'],2))
treeMean = (tree['left']+tree['right'])/2.0
errorMerge = sum(power(testData[:,-1] - treeMean,2))
if errorMerge < errorNoMerge:
print ("merging")
return treeMean
else: return tree
else: return tree