def binSplitDataSet(dataSet,feature,value):
bigIndex = dataSet[:,feature] > value
smallIndex = dataSet[:,feature] <= value
#print('bigIndex:',bigIndex)
#print('smallIndex:',smallIndex)
big = dataSet[nonzero(bigIndex)[0],:]
small = dataSet[nonzero(smallIndex)[0],:]
return small,big
第一个函数有三个参数:数据集合,待切分的特征和该特征的某个值,在给定特征和特征值的情况下,该函数通过数组过滤方式将上述数据集合切分得到两个子集并返回
def chooseBestSplit(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
if len(set(dataSet[:,-1].T.A.tolist()[0]))==1:
return None,leafType(dataSet)
eS=ops[0]
minSample=ops[1]
bestError=errType(dataSet)
originError=errType(dataSet)
bestFeatIndex=0
bestSplitValue=0
m,n=shape(dataSet)
for featIndex in range (n-1):
for splitValue in set(dataSet[:,featIndex].T.A.tolist()[0]):
mat0,mat1=binSplitDataSet(dataSet,featIndex,splitValue)
if shape(mat0)[0]<minSample or shape(mat1)[0]<minSample:
continue
newS=errType(mat0)+errType(mat1)
#判断newS是否是左右切分
if newS < bestError:
bestError=newS
bestFeatIndex=featIndex
bestSplitValue=splitValue
#切分效果不佳
if originError-bestError<eS:
return None,leafType(dataSet)
#可以切分
mat0,mat1=binSplitDataSet(dataSet,bestFeatIndex,bestSplitValue)
if shape(mat0)[0]<minSample or shape(mat1)[0]<minSample:
return None,leafType(dataSet)
#返回本次切分的最优特征和切分值
return bestFeatIndex,bestSplitValue
chooseBestSplit()函数目标是找到数据集切分的最佳位置。
它遍历所有的特征及其可能的取值来找到误差最小化的切分阈值。
对每个特征:
对每个特征值:
将数据集切分成两份(小于该特征值的数据样本放在左子树,否则放在右子树)
计算切分的误差
如果当前误差小于当前最小误差,那么将当前切分设定为最佳切分并更新最小误差
返回最佳切分的特征和阈值
Args:
dataSet 加载的原始数据集
leafType 建立叶子点的函数
errType 误差计算函数(求总方差)
ops [容许误差下降值,切分的最少样本数]。非常重要,因为它决定了决策树划分停止的threshold值,被称为预剪枝(prepruning),
其实也就是用于控制函数的停止时机。
之所以这样说,是因为它防止决策树的过拟合,所以当误差的下降值小于tolS,或划分后的集合size小于tolN时,选择停止继续划分。
Returns:
bestIndex feature的index坐标
bestValue 切分的最优值
def createTree(dataSet,leafType=regLeaf,errType=regErr,ops=(1,4)):
feat,val=chooseBestSplit(dataSet,leafType,errType,ops)
if feat==None:
return val
retTree={}
retTree['spInd']=feat
retTree['spVal']=val
#大于在右边,小于在左边,分为两个数据集
lSet,rSet=binSplitDataSet(dataSet,feat,val)
#递归的进行调用,在左右子树中继续递归生成树
retTree['left']=createTree(lSet,leafType,errType,ops)
retTree['right']=createTree(rSet,leafType,errType,ops)
return retTree
步骤:找到最佳的待切分特征
如果该节点不能再分,将该节点存为叶节点
执行二元切分
在右子树调用createTree()方法
在左子树调用createTree()方法
Args:
dataSet 加载的原始数据集
leafType 建立叶子点的函数
errType 误差计算函数(求总方差)
ops [容许误差下降值,切分的最少样本数]。非常重要,因为它决定了决策树划分停止的threshold值,被称为预剪枝