# Regression tree built with CART, plus pre-pruning and post-pruning of the tree; the code follows:
"""
机器学习-树回归(CART)
姓名:pcb
日期:2019.01.10
"""
from numpy import *
class treeNode():
    """Node of a binary tree: a split feature, a split value and two branches.

    The tree-building code in this file represents trees as nested dicts and
    does not instantiate this class.

    Attributes:
        featureToSplitOn: value passed as ``feat``.
        valueOfSplit: value passed as ``val``.
        rightBranch: value passed as ``right``.
        leftBranch: value passed as ``left``.
    """

    def __init__(self, feat, val, right, left):
        # Bind to the instance. Assigning to bare names only creates locals
        # that vanish when __init__ returns, leaving the node empty.
        self.featureToSplitOn = feat
        self.valueOfSplit = val
        self.rightBranch = right
        self.leftBranch = left
#加载数据
def loadDataSet(filename):
    """Read a tab-separated text file of numbers into a list of float rows.

    :param filename: path of the file to read
    :return: list of rows, each row a list of floats (one per tab-separated field)
    :raises ValueError: if a field cannot be converted with ``float``
    """
    dataMat = []
    # The context manager closes the file even if a conversion fails.
    with open(filename) as fr:
        for line in fr:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no data;
                # float('') would otherwise raise ValueError.
                continue
            dataMat.append([float(field) for field in stripped.split('\t')])
    return dataMat
#在给定特征和特征值的情况下,通过数组过滤的方式将上述数据切分得到两个子集
def binSplitDataSet(dataSet, feature, value):
    """Split the rows of ``dataSet`` in two by thresholding one column.

    :param dataSet: 2-D numpy array/matrix holding the data
    :param feature: index of the column to compare
    :param value: threshold applied to that column
    :return: ``(mat0, mat1)`` where ``mat0`` holds the rows whose column is
             greater than ``value`` and ``mat1`` the rows whose column is
             less than or equal to ``value``
    """
    column = dataSet[:, feature]
    # nonzero(...)[0] gives the row indices where the comparison holds.
    rows_above = nonzero(column > value)[0]
    rows_not_above = nonzero(column <= value)[0]
    return dataSet[rows_above, :], dataSet[rows_not_above, :]
#节点的模型生成函数
def regLeaf(dataSet):
    """Leaf model: return the mean of the last column of ``dataSet``."""
    targets = dataSet[:, -1]
    return mean(targets)
#平方误差估计函数
def regErr(dataSet):
    """Return the variance of the last column times the number of rows.

    Multiplying the variance by the row count gives the total squared
    deviation of the last column from its mean.
    """
    row_count = shape(dataSet)[0]
    return var(dataSet[:, -1]) * row_count
#1.作用:用最佳的方式切分数据集和生成相应的叶节点
#2.给定某个误差计算方法该函数会找到数据集上的最佳二元切分方式
#3.确定停止切分,并形成一个叶节点
#4.目标:找到数据集切分的最佳位置(通过遍历所有特征及其可能的取值来找到使误差最小化的切分阈值。)
#5.函数的伪代码:
"""
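For each feature:
    For each value of that feature:
        Split the data into two parts
        Compute the error of the split
        If this error is below the smallest error so far, make this split the best split and update the smallest error
Return the feature and threshold of the best split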
"""
def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Search every feature/value pair for the binary split with the lowest error.

    :param dataSet: data set; the last column is the target. The code uses
                    ``.A`` on a column, so a numpy matrix is expected.
    :param leafType: function called on ``dataSet`` to build a leaf value
    :param errType: function called on a data set to measure its error
    :param ops: ``ops[0]`` is the minimum error reduction a split must
                achieve, ``ops[1]`` the minimum number of rows on each side
    :return: ``(feature index, split value)`` for the best split, or
             ``(None, leaf value)`` when no split should be made
    """
    min_gain, min_rows = ops[0], ops[1]

    def too_small(part_a, part_b):
        # A split is rejected when either side has fewer than min_rows rows.
        return shape(part_a)[0] < min_rows or shape(part_b)[0] < min_rows

    # All target values identical: nothing to split.
    targets = dataSet[:, -1].T.tolist()[0]
    if len(set(targets)) == 1:
        return None, leafType(dataSet)

    _, col_count = shape(dataSet)
    base_err = errType(dataSet)
    best_err, best_feat, best_val = inf, 0, 0
    for feat in range(col_count - 1):
        # Distinct values of this column are the candidate thresholds.
        candidates = set(dataSet[:, feat].T.A.tolist()[0])
        for threshold in candidates:
            upper, lower = binSplitDataSet(dataSet, feat, threshold)
            if too_small(upper, lower):
                continue
            split_err = errType(upper) + errType(lower)
            if split_err < best_err:
                best_err, best_feat, best_val = split_err, feat, threshold

    # Error reduction too small to be worth a split.
    if base_err - best_err < min_gain:
        return None, leafType(dataSet)

    # Final guard: the chosen split must leave enough rows on both sides.
    upper, lower = binSplitDataSet(dataSet, best_feat, best_val)
    if too_small(upper, lower):
        return None, leafType(dataSet)
    return best_feat, best_val
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    """Recursively build a tree as nested dicts.

    :param dataSet: data set passed on to ``chooseBestSplit``
    :param leafType: function that builds a leaf value
    :param errType: function that measures the error of a data set
    :param ops: split-control tuple forwarded to ``chooseBestSplit``
    :return: the leaf value when ``chooseBestSplit`` reports no split;
             otherwise a dict with keys ``'spInd'`` (split feature index),
             ``'spVal'`` (split value), ``'left'`` (subtree for rows whose
             feature is greater than the split value) and ``'right'``
             (subtree for the remaining rows)
    """
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)
    # Identity test: None is a singleton, and feature index 0 is a valid split.
    if feat is None:
        return val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    return {
        'spInd': feat,
        'spVal': val,
        'left': createTree(lSet, leafType, errType, ops),
        'right': createTree(rSet, leafType, errType, ops),
    }
#-----------回归树的后剪枝----------------------------------------
#判断当前处理的节点是否是叶子节点
def isTree(obj):
    """Return True when ``obj`` is a dict (a subtree), False otherwise (a leaf).

    ``isinstance`` is used instead of comparing the type's name, so dict
    subclasses are recognised and unrelated classes that merely happen to be
    called ``dict`` are not.
    """
    return isinstance(obj, dict)
#1.该函数对树进行塌陷处理(返回树的平均值)
#2.从上到下遍历到叶节点为止,如果找到两个叶节点则计算他们的平均值
def getMean(tree):
    """Collapse a tree into a single value by averaging its branches.

    Subtrees are collapsed recursively and replaced in place in ``tree``, so
    the dict passed in is modified.

    :param tree: tree dict with ``'left'`` and ``'right'`` entries
    :return: the average of the (collapsed) left and right values
    """
    if isTree(tree['right']):
        tree['right'] = getMean(tree['right'])
    if isTree(tree['left']):
        # Recurse into the LEFT subtree. Passing tree['right'] here would
        # hand getMean an already-collapsed number and raise TypeError.
        tree['left'] = getMean(tree['left'])
    return (tree['left'] + tree['right']) / 2.0
#伪代码
"""
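Split the test data set using the existing tree:
    If either subset is a tree, recursively prune on that subset
    Compute the error after merging the two current leaf nodes
    Compute the error without merging
    If merging lowers the error, merge the leaf nodes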
"""
def prune(tree, testData):
    """Post-prune ``tree`` using ``testData``.

    :param tree: tree dict to prune; its branches are updated in place
    :param testData: test data set used to judge whether merging leaves helps
    :return: the tree itself, or a single merged leaf value when the node
             was collapsed
    """
    # No test rows reach this node: collapse the whole subtree to its mean.
    if shape(testData)[0] == 0:
        return getMean(tree)

    # Prune any subtree with the share of the test data that falls into it.
    if isTree(tree['left']) or isTree(tree['right']):
        upper, lower = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        if isTree(tree['left']):
            tree['left'] = prune(tree['left'], upper)
        if isTree(tree['right']):
            tree['right'] = prune(tree['right'], lower)

    # Merging is only considered when both branches are now leaves.
    if isTree(tree['left']) or isTree(tree['right']):
        return tree

    upper, lower = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    left_leaf = tree['left']
    right_leaf = tree['right']
    # Squared error on the test data when the two leaves are kept separate.
    split_error = sum(power(upper[:, -1] - left_leaf, 2)) + \
        sum(power(lower[:, -1] - right_leaf, 2))
    # Squared error when both leaves are replaced by their average.
    merged_leaf = (left_leaf + right_leaf) / 2.0
    merged_error = sum(power(testData[:, -1] - merged_leaf, 2))
    if merged_error < split_error:
        print("merging")
        return merged_leaf
    return tree
#----------------------------------------------------------------
def main():
    """Grow a regression tree from ``ex2.txt`` and post-prune it with ``ex2test.txt``.

    Both files are read from the current working directory. The tree is
    printed before pruning and the pruned result is printed afterwards.
    """
    # Example 1 (disabled): grow and print a tree from ex0.txt.
    # myDat = loadDataSet('ex0.txt')
    # myMat = asmatrix(myDat)
    # myTree = createTree(myMat)
    # print(myTree)

    # Example 2: post-pruning test.
    # asmatrix replaces the alias `mat`, which NumPy 2.0 removed from the
    # top-level namespace; asmatrix is available on both 1.x and 2.x.
    myDat2 = loadDataSet('ex2.txt')
    myMat2 = asmatrix(myDat2)
    myTree = createTree(myMat2, ops=(100, 1))
    print(myTree)
    myDatTest = loadDataSet('ex2test.txt')
    myMat2Test = asmatrix(myDatTest)
    pruneTree = prune(myTree, myMat2Test)
    print(pruneTree)


if __name__ == '__main__':
    main()