[Python] Decision tree: CART


Excerpted from "Machine Learning in Action" by Peter Harrington (the Chinese edition is titled "Machine Learning Practice").
This article introduces the CART algorithm, implemented in Python and run in Jupyter.

1 Local modeling of complex data

Decision tree construction is a greedy algorithm: it makes the best choice available at each step, without caring whether that choice leads to the global optimum.

Tree regression:

  1. Advantages: can model complex, nonlinear data
  2. Disadvantages: the results are not easy to interpret

Disadvantages of the ID3 algorithm from the previous article:

  1. It splits the data too quickly: each time, the current best feature is chosen and the data is split on every possible value of that feature. If a feature has 4 values, the data is cut into 4 parts, and that feature is no longer used for the rest of the algorithm.
  2. It cannot handle continuous features directly.
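The contrast between the two splitting styles can be sketched in a few lines of plain Python. The toy feature values and the helper names `multiway_split` and `binary_split` are hypothetical, chosen only for illustration: on a continuous feature where every value is distinct, an ID3-style split produces one branch per value, while a CART-style binary split needs only a threshold and leaves the feature available for further splits deeper in the tree.

```python
# Toy continuous feature (hypothetical data): every value is distinct.
values = [0.12, 0.35, 0.48, 0.51, 0.77, 0.93]

def multiway_split(xs):
    """ID3-style split: one branch per distinct feature value."""
    branches = {}
    for x in xs:
        branches.setdefault(x, []).append(x)
    return branches

def binary_split(xs, threshold):
    """CART-style split: values > threshold vs values <= threshold."""
    greater = [x for x in xs if x > threshold]
    less_equal = [x for x in xs if x <= threshold]
    return greater, less_equal

print(len(multiway_split(values)))  # 6 branches, each holding a single sample
print(binary_split(values, 0.48))   # ([0.51, 0.77, 0.93], [0.12, 0.35, 0.48])
```

With the binary split, each half can later be split again on the same feature using a different threshold.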

General approach to tree regression

  1. Collect data: any method
  2. Prepare data: nominal data should be mapped to binary values
  3. Analyze data: plot the data to visualize it, and generate the tree in the form of a dictionary
  4. Train the algorithm: most of the time is spent building the tree models at the leaf nodes
  5. Test the algorithm: use the R² (R-squared) value on the test data to evaluate the model
  6. Use the algorithm: use the trained tree as a prediction model; the predictions can be used for many purposes
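Step 5 relies on the R² value. The sketch below uses one common definition, R² = 1 - SS_res / SS_tot. The function name `r_squared` is hypothetical, and the book may compute the score differently (a common alternative is the squared Pearson correlation obtained from `numpy.corrcoef`).

```python
import numpy as np

def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    ss_res = np.sum((y_true - y_pred) ** 2)         # residual sum of squares
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total sum of squares around the mean
    return 1.0 - ss_res / ss_tot

y_test = [1.0, 2.0, 3.0, 4.0]
print(r_squared(y_test, [1.0, 2.0, 3.0, 4.0]))  # 1.0: perfect predictions
print(r_squared(y_test, [2.5, 2.5, 2.5, 2.5]))  # 0.0: no better than predicting the mean
```

A value close to 1 means the tree explains most of the variance in the test targets.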

2 Building trees with continuous and discrete features

The tree is stored as a nested dictionary. Each node contains the following 4 elements:

  1. spInd: the feature to split on
  2. spVal: the value of that feature at which to split
  3. right: the right subtree. When no further splitting is needed, it can also be a single value (a leaf)
  4. left: the left subtree, similar to the right subtree
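To make the nested-dictionary format concrete, here is a hypothetical two-level tree written by hand, plus two small helpers (`is_tree`, `count_leaves` and `depth` are illustrative names, not part of regTrees). Internal nodes are dicts with the four keys above; leaves are plain numbers.

```python
# A hypothetical tree in the dictionary format described above.
tree = {
    'spInd': 0, 'spVal': 0.5,
    'left': {'spInd': 0, 'spVal': 0.8, 'left': 3.0, 'right': 2.0},
    'right': 1.0,
}

def is_tree(obj):
    """Internal nodes are dicts; anything else is a leaf value."""
    return isinstance(obj, dict)

def count_leaves(node):
    """Recursively count the leaf values below a node."""
    if not is_tree(node):
        return 1
    return count_leaves(node['left']) + count_leaves(node['right'])

def depth(node):
    """Number of split levels on the longest path from the root to a leaf."""
    if not is_tree(node):
        return 0
    return 1 + max(depth(node['left']), depth(node['right']))

print(count_leaves(tree))  # 3
print(depth(tree))         # 2
```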
from numpy import *
import regTrees
#4x4 identity matrix
testMat = mat(eye(4))
#split on feature column 1 (numbered from 0): rows > 0.5 go to mat0, rows <= 0.5 go to mat1
mat0,mat1 = regTrees.binSplitDataSet(testMat,1,0.5)
print (mat0,'\n')
print (mat1)

The result is

[[ 0.  1.  0.  0.]] 

[[ 1.  0.  0.  0.]
 [ 0.  0.  1.  0.]
 [ 0.  0.  0.  1.]]
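NumPy's documentation no longer recommends the `np.matrix` class used here and suggests regular arrays instead. As a sketch under that assumption, the same split can be written with a boolean mask on an ndarray; `bin_split` is a hypothetical name that mirrors `binSplitDataSet()`, returning the `>` rows first and the `<=` rows second.

```python
import numpy as np

def bin_split(data, feature, value):
    """ndarray version of binSplitDataSet(): rows with data[:, feature] > value
    first, rows with data[:, feature] <= value second."""
    mask = data[:, feature] > value
    return data[mask], data[~mask]

test = np.eye(4)                       # 4x4 identity matrix
greater, less_equal = bin_split(test, 1, 0.5)
print(greater)                         # the single row [0, 1, 0, 0]
print(less_equal.shape)                # (3, 4): the other three rows
```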

3 Using the CART algorithm for regression

3.1 Building the tree

The tree is built recursively by createTree(). Its pseudocode is roughly as follows:

Find the best feature to split on:
  If the node cannot be split, save it as a leaf node
  Otherwise, perform a binary split
  Call createTree() on the right subtree
  Call createTree() on the left subtree

The core of the code is the chooseBestSplit() function. Given an error measure, this function finds the best binary split of the dataset. It also decides when to stop splitting; once splitting stops, a leaf node is generated.
So chooseBestSplit() needs to do only two things:

1. Split the dataset in the best possible way
2. Generate the corresponding leaf node

The pseudocode is as follows:

For each feature:
  For each value of that feature:
    Split the dataset into two parts
    Calculate the error of the split
    If the current error is less than the minimum error, set the current split as the best split and update the minimum error
Return the feature and threshold of the best split
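The search above can be sketched with plain NumPy arrays. This is an illustrative version, not the regTrees code: `total_sq_error` computes the same quantity as `regErr()` (variance times sample count, i.e. the sum of squared deviations from the mean), `best_split` and the toy data are hypothetical, and the `>` / `<=` rule matches `binSplitDataSet()`. The `tolS` check is left out here to keep the search itself visible.

```python
import numpy as np

def total_sq_error(y):
    """Same quantity as regErr(): variance of y times the number of samples."""
    y = np.asarray(y, dtype=float)
    return np.var(y) * len(y) if len(y) else 0.0

def best_split(X, y, tol_n=1):
    """Exhaustive search over every feature and every distinct feature value.
    Returns (feature index, split value, total error after the split)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    best = (None, None, np.inf)
    for feat in range(X.shape[1]):               # for each feature
        for val in np.unique(X[:, feat]):        # for each value of that feature
            mask = X[:, feat] > val              # same rule as binSplitDataSet()
            if mask.sum() < tol_n or (~mask).sum() < tol_n:
                continue                         # one subset would be too small
            err = total_sq_error(y[mask]) + total_sq_error(y[~mask])
            if err < best[2]:                    # keep the lowest-error split
                best = (feat, val, err)
    return best

X = [[0.1], [0.2], [0.3], [0.7], [0.8], [0.9]]
y = [0.0, 0.1, 0.0, 1.0, 1.1, 1.0]
feat, val, err = best_split(X, y)
print(feat, val)  # 0 0.3: separates the low targets from the high ones
```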

If no good binary split can be found, the function returns None as the feature index, together with a leaf value produced by leafType(); createTree() then returns that value as a leaf node.

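Note: there are three cases in which the function does not split: all target values are identical; the best split reduces the error by less than tolS; or the best split would leave a subset with fewer than tolN samples.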
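The three exit conditions can be isolated in a small helper. This is a hypothetical sketch: `stop_reason` and its parameter names are illustrative stand-ins for `S`, `bestS` and `ops=(tolS, tolN)` in chooseBestSplit(), and the checks are made in the same order as there.

```python
def stop_reason(y, error_before, best_error, n_left, n_right, tolS=1.0, tolN=4):
    """Return which exit condition of chooseBestSplit() would fire, or None."""
    if len(set(y)) == 1:
        return 'all target values identical'        # exit condition 1
    if error_before - best_error < tolS:
        return 'error reduction smaller than tolS'  # exit condition 2
    if n_left < tolN or n_right < tolN:
        return 'subset smaller than tolN'           # exit condition 3
    return None                                     # go ahead and split

print(stop_reason([2.0, 2.0, 2.0], 0.0, 0.0, 0, 0))  # all target values identical
print(stop_reason([1.0, 5.0], 10.0, 9.5, 5, 5))      # error reduction smaller than tolS
print(stop_reason([1.0, 5.0], 10.0, 2.0, 5, 5))      # None
```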

3.2 Running the code

from numpy import *
import regTrees
myDat = regTrees.loadDataSet('ex00.txt')  
myMat = mat(myDat)  
regTrees.createTree(myMat)

The result is

{'left': 1.0180967672413792,
 'right': -0.044650285714285719,
 'spInd': 0,
 'spVal': 0.48813}
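To use this tree for prediction, walk it from the root until a leaf is reached. One detail from the regTrees code matters: binSplitDataSet() puts rows with feature > spVal into mat0, and createTree() stores that subset as 'left', so `>` goes left and `<=` goes right. `predict` is a hypothetical helper sketched under that rule.

```python
# The tree printed above, copied verbatim.
tree = {'left': 1.0180967672413792,
        'right': -0.044650285714285719,
        'spInd': 0,
        'spVal': 0.48813}

def predict(node, x):
    """Walk the tree for one sample x (a sequence of feature values)."""
    while isinstance(node, dict):
        if x[node['spInd']] > node['spVal']:
            node = node['left']     # feature > spVal: left subtree
        else:
            node = node['right']    # feature <= spVal: right subtree
    return node

print(predict(tree, [0.9]))  # the 'left' leaf, about 1.018
print(predict(tree, [0.2]))  # the 'right' leaf, about -0.045
```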

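The dataset ex00.txt is as follows (tab-separated; the first column is the feature and the last column is the target value):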

0.203693    -0.064036
0.355688    -0.119399
0.988852    1.069062
0.518735    1.037179
0.514563    1.156648
0.976414    0.862911
0.919074    1.123413
0.697777    0.827805
0.928097    0.883225
0.900272    0.996871
0.344102    -0.061539
0.148049    0.204298
0.130052    -0.026167
0.302001    0.317135
0.337100    0.026332
0.314924    -0.001952
0.269681    -0.165971
0.196005    -0.048847
0.129061    0.305107
0.936783    1.026258
0.305540    -0.115991
0.683921    1.414382
0.622398    0.766330
0.902532    0.861601
0.712503    0.933490
0.590062    0.705531
0.723120    1.307248
0.188218    0.113685
0.643601    0.782552
0.520207    1.209557
0.233115    -0.348147
0.465625    -0.152940
0.884512    1.117833
0.663200    0.701634
0.268857    0.073447
0.729234    0.931956
0.429664    -0.188659
0.737189    1.200781
0.378595    -0.296094
0.930173    1.035645
0.774301    0.836763
0.273940    -0.085713
0.824442    1.082153
0.626011    0.840544
0.679390    1.307217
0.578252    0.921885
0.785541    1.165296
0.597409    0.974770
0.014083    -0.132525
0.663870    1.187129
0.552381    1.369630
0.683886    0.999985
0.210334    -0.006899
0.604529    1.212685
0.250744    0.046297

4 Pruning

4.1 Pre-pruning

The simple experimental results in Section 3 are quite satisfactory, but tree construction is very sensitive to the input parameters tolS and tolN, and it is not easy to get such good results with other values.
The termination conditions in chooseBestSplit() are in effect a form of prepruning. Controlling the tree this way means repeatedly adjusting the ops parameters, which is not a good approach, because sometimes we don't know what kind of result we are looking for.
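One reason for this sensitivity: regErr() is a sum of squared errors, so multiplying the target by a constant c multiplies every error, and therefore every error reduction S - bestS, by c². A tolS that works on one scale can be far too small or too large on another. The sketch below checks this with a hypothetical `reg_err` that uses the same formula as regErr() on made-up targets.

```python
import numpy as np

def reg_err(y):
    """Same formula as regErr(): variance of the targets times the sample count."""
    y = np.asarray(y, dtype=float)
    return np.var(y) * len(y)

y = np.array([0.1, -0.2, 1.0, 1.2])   # hypothetical target values
ratio = reg_err(100 * y) / reg_err(y)
print(ratio)  # about 10000: scaling y by 100 scales the error by 100**2
```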

Another form of pruning, called postpruning, uses both a training set and a test set; it is a more ideal way to prune.

4.2 Post-pruning

The pseudocode of the prune() function is as follows:

Split the test data according to the existing tree:
  If either subset is a tree, recursively prune that subset
  Calculate the error after merging the two current leaf nodes
  Calculate the error without merging
  If merging reduces the error, merge the leaf nodes
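The merge decision at a single node can be sketched on its own. `should_merge` and the toy values below are hypothetical; the logic follows prune(): compare the test-set squared error of keeping two leaves against the error of replacing them with their plain (unweighted) average.

```python
import numpy as np

def should_merge(left_val, right_val, y_left, y_right):
    """y_left / y_right are the test targets routed to each leaf.
    Returns (merge?, merged leaf value)."""
    y_left = np.asarray(y_left, dtype=float)
    y_right = np.asarray(y_right, dtype=float)
    error_no_merge = (np.sum((y_left - left_val) ** 2)
                      + np.sum((y_right - right_val) ** 2))
    merged_val = (left_val + right_val) / 2.0        # prune() uses the plain average
    y_all = np.concatenate([y_left, y_right])
    error_merge = np.sum((y_all - merged_val) ** 2)
    return error_merge < error_no_merge, merged_val

# Leaves 1.0 and 1.2, but all test targets sit near 1.1: merging helps.
print(should_merge(1.0, 1.2, [1.1, 1.1], [1.1, 1.1])[0])  # True
# Leaves 0.0 and 1.0 that match their test targets well: keep the split.
print(should_merge(0.0, 1.0, [0.0, 0.1], [1.0, 0.9])[0])  # False
```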

from numpy import *
import regTrees
myDat2 = regTrees.loadDataSet('ex2.txt')  
myMat2 = mat(myDat2)  
regTrees.createTree(myMat2)

The resulting tree has many nodes.


To try postpruning, first build a deliberately large tree with ops=(0,1) (tolS = 0, tolN = 1), then prune it using the test data:

myTree = regTrees.createTree(myMat2,ops=(0,1))
myDatTest = regTrees.loadDataSet('ex2test.txt')
myMat2Test = mat(myDatTest)
regTrees.prune(myTree,myMat2Test)

The output shows that many nodes are merged.

The complete regTrees.py module:

from numpy import *

def loadDataSet(fileName):      #general function to parse tab-delimited floats
    dataMat = []                #assume last column is target value
    with open(fileName) as fr:  #the file is closed automatically
        for line in fr:
            if not line.strip(): continue  #skip blank lines
            curLine = line.strip().split('\t')
            fltLine = [float(item) for item in curLine]  #map all elements to float
            dataMat.append(fltLine)
    return dataMat

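#Parameters: the dataset, the feature to split on, and a value of that feature
#Rows with feature > value go to mat0 (the left subtree); the rest go to mat1 (the right subtree)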
def binSplitDataSet(dataSet, feature, value):
    mat0 = dataSet[nonzero(dataSet[:,feature] > value)[0],:]
    mat1 = dataSet[nonzero(dataSet[:,feature] <= value)[0],:]
    return mat0,mat1

def regLeaf(dataSet):#returns the value used for each leaf
    return mean(dataSet[:,-1])#mean of the last column (the target)

def regErr(dataSet):#total squared error: target variance times number of samples
    return var(dataSet[:,-1]) * shape(dataSet)[0]

def chooseBestSplit(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):
    tolS = ops[0]; tolN = ops[1]#defaults: 1 and 4
    #tolS: minimum error reduction required to split; tolN: minimum number of samples in each subset
    #if all the target values are the same, no split is needed: quit and return a leaf value
    if len(set(dataSet[:,-1].T.tolist()[0])) == 1: #exit condition 1
        return None, leafType(dataSet)

    m,n = shape(dataSet)
    #the choice of the best feature is driven by Reduction in RSS error from mean
    S = errType(dataSet)
    bestS = inf; bestIndex = 0; bestValue = 0
    for featIndex in range(n-1):
        #iterate over the distinct values of this feature
        for splitVal in set(dataSet[:,featIndex].T.A.tolist()[0]):
            mat0, mat1 = binSplitDataSet(dataSet, featIndex, splitVal)
            if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN): continue
            newS = errType(mat0) + errType(mat1)
            if newS < bestS: 
                bestIndex = featIndex
                bestValue = splitVal
                bestS = newS
    #if the decrease (S-bestS) is less than tolS, the split does not help enough: create a leaf node instead
    if (S - bestS) < tolS: 
        return None, leafType(dataSet) #exit condition 2

    mat0, mat1 = binSplitDataSet(dataSet, bestIndex, bestValue)
    #check the subset sizes: if either is smaller than tolN, don't split either
    if (shape(mat0)[0] < tolN) or (shape(mat1)[0] < tolN):  #exit condition 3
        return None, leafType(dataSet)

    #return the best feature to split on and the value used for that split
    return bestIndex,bestValue

#Parameters: dataset, leaf-creation function, error function, and a tuple of other tree-building parameters
def createTree(dataSet, leafType=regLeaf, errType=regErr, ops=(1,4)):#assume dataSet is a NumPy matrix so we can use array filtering
    feat, val = chooseBestSplit(dataSet, leafType, errType, ops)#choose the best split
    if feat is None: return val #if the splitting hit a stop condition return val
    retTree = {}
    retTree['spInd'] = feat
    retTree['spVal'] = val
    lSet, rSet = binSplitDataSet(dataSet, feat, val)
    retTree['left'] = createTree(lSet, leafType, errType, ops)
    retTree['right'] = createTree(rSet, leafType, errType, ops)
    return retTree  

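#Regression tree pruning
#Test whether the input is a tree (a dict), i.e. whether the current node is an internal node rather than a leaf
def isTree(obj):
    return isinstance(obj, dict)

#Recursively walk down the tree to the leaves; when two leaves are found, return their average. This collapses the tree into a single value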
def getMean(tree):
    if isTree(tree['right']): tree['right'] = getMean(tree['right'])
    if isTree(tree['left']): tree['left'] = getMean(tree['left'])
    return (tree['left']+tree['right'])/2.0

def prune(tree, testData):
    if shape(testData)[0] == 0: return getMean(tree) #if we have no test data collapse the tree
    if (isTree(tree['right']) or isTree(tree['left'])):#if either branch is a subtree, split the test data so it can be pruned
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
    if isTree(tree['left']): tree['left'] = prune(tree['left'], lSet)
    if isTree(tree['right']): tree['right'] =  prune(tree['right'], rSet)
    #if they are now both leafs, see if we can merge them
    if not isTree(tree['left']) and not isTree(tree['right']):
        lSet, rSet = binSplitDataSet(testData, tree['spInd'], tree['spVal'])
        errorNoMerge = sum(power(lSet[:,-1] - tree['left'],2)) +\
            sum(power(rSet[:,-1] - tree['right'],2))
        treeMean = (tree['left']+tree['right'])/2.0
        errorMerge = sum(power(testData[:,-1] - treeMean,2))
        if errorMerge < errorNoMerge: 
            print ("merging")
            return treeMean
        else: return tree
    else: return tree
