Create a decision tree

Article outline:
(1) Analyze how to divide the data from the perspective of information theory
(2) Apply the mathematical formulas to an actual data set
(3) Draw the decision tree
(4) A practical example
(5) Experiment summary

1. What is a decision tree?

A decision tree is exactly what it sounds like: a tree that makes decisions.

A common example is spam detection (Naive Bayes works for this too): check the domain the email was sent from, examine the message text for words that often appear in spam (such as "discount", "free", "buy"), and then decide whether the message is spam or a normal email.

Why choose a decision tree?
Advantages: low computational complexity; the output is easy to understand.
Disadvantages: prone to overfitting.
Applicable data: numerical and nominal values.

2. Construction of Decision Tree

Process:
(1) Collect the data: any method will do (the data set is provided at the end of the article, haha)
(2) Prepare the data: the tree-construction algorithm only works on nominal data, so numerical data must be discretized first (much like amplitude quantization in analog-to-digital conversion; see the sketch after this list)
(3) Analyze the data: any method will do; once the tree is built, check promptly that the drawn graph matches expectations
(4) Train the algorithm: build the tree data structure
(5) Test the algorithm: compute the error rate with the trained tree
(6) Use the algorithm: write a blog post whose code runs without errors, hahaha
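The discretization mentioned in step (2) is not shown anywhere in the article; a minimal sketch of one way to do it, using hand-picked bin edges that are purely illustrative, might look like this:

def discretize(value, thresholds=(10.0, 20.0)):
    # Map a numeric value to a nominal label using hand-picked cut points.
    # The thresholds are illustrative; in practice they would come from
    # quantiles of the data or from domain knowledge.
    if value < thresholds[0]:
        return 'low'
    elif value < thresholds[1]:
        return 'medium'
    return 'high'

ages = [5.0, 14.2, 33.7]
print([discretize(a) for a in ages])   # ['low', 'medium', 'high']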

The ID3 algorithm splits the data set on one feature at a time. Which feature should be selected as the basis for each split?

2.1 Information gain

Basis for classification: similar data should fall into the same category. After each split, the data within the same branch is more similar, and removing the feature used for the split leaves the data set less disordered, i.e. tidier.

How do we measure how ordered the data is?
This is where information theory comes in.
The change in information before and after the data set is split is called the information gain, and the feature with the highest information gain is the best basis for a split.
The measure of information in a set is called Shannon entropy, or simply entropy.
The larger the entropy, the more chaotic and disordered the data, and the lower the similarity within it.

Entropy is the expected value of the information. If the items to be classified can fall into multiple categories, the information of the symbol $x_i$ is defined as

$$l(x_i) = -\log_2 p(x_i)$$

where $p(x_i)$ is the probability of that category. To obtain the entropy of the data set we sum this quantity over all $n$ categories:

$$H = -\sum_{i=1}^{n} p(x_i)\log_2 p(x_i)$$

With these formulas we can analyze how the entropy of the data set changes before and after a split.

Calculate the Shannon entropy of the dataset

from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:                     # count the occurrences of each class label
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys(): labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)       # log base 2
    return shannonEnt

The higher the entropy, the more mixed the data. We can add more categories to the data set and watch how its entropy changes; a sketch of that experiment follows.
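The article never lists its sample data set. The sketch below assumes the five-record data set whose labels match the 'no surfacing'/'flippers' names used later in retrieveTree; treat it as an illustration, not the article's actual data:

# Hypothetical sample data: the last column is the class label.
myDat = [[1, 1, 'yes'],
         [1, 1, 'yes'],
         [1, 0, 'no'],
         [0, 1, 'no'],
         [0, 1, 'no']]

print(calcShannonEnt(myDat))   # about 0.971 with two classes (2 'yes', 3 'no')

# Introduce a third class label and the entropy rises,
# reflecting a more mixed (more disordered) data set.
myDat[0][-1] = 'maybe'
print(calcShannonEnt(myDat))   # about 1.371 with three classes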

2.2 Dividing the data set

splitDataSet divides the data set on a given feature: the first parameter is the data set to be split, the second is the index of the feature to split on, and the third is the feature value to match. Records with that value are returned with the feature column removed.

def splitDataSet(dataSet, axis, value):
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]     #chop out axis used for splitting
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

Example: split the data set on feature 0 and return the records whose feature 0 equals 1, as in the usage sketch below.
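A minimal usage sketch, again assuming the hypothetical five-record myDat from the entropy example above:

myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

# Keep the records whose feature 0 equals 1, with that column removed.
print(splitDataSet(myDat, 0, 1))   # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(myDat, 0, 0))   # [[1, 'no'], [1, 'no']]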

2.3 Choosing the best data set division method

In the previous section we learned how to measure how mixed a data set is. Our goal is to keep splitting the data set into smaller branches, following its internal structure, until the leaf nodes each hold a single class.

For each feature, we compute the entropy of the data set after splitting on it, and then judge whether that feature gives the best split (which is exactly a decision process):

def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):        #iterate over all the features
        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
        uniqueVals = set(featList)       #get a set of unique values
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = i
    return bestFeature 

Here each feature in turn is used to split the data set, and the entropy newEntropy of the split data is computed. The feature whose split yields the smallest new entropy, i.e. the largest information gain relative to the original entropy, reduces the disorder of the data the most, and that feature is the best one to split on.
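A quick sanity check, still assuming the hypothetical five-record sample set from the earlier sketches:

myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

# Splitting on feature 0 gives the larger information gain here,
# so index 0 is returned as the best feature to split on.
print(chooseBestFeatureToSplit(myDat))   # 0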

2.4 Recursively build decision trees

The splitting described above ends when every record has been assigned to a leaf node. Each time a feature is chosen, the data set is divided into several smaller data sets; we simply apply chooseBestFeatureToSplit to each branch's data set and split again, and recursion implements this process neatly.
In practice we can also limit the number of leaf nodes (that is, the number of categories, as the C4.5 and CART algorithms do).
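Note that createTree below calls a helper majorityCnt that this article never lists; it should return the most frequent class label when the features are exhausted but a leaf is still mixed. A minimal sketch of such a helper:

from collections import Counter

def majorityCnt(classList):
    # Return the class label that occurs most often in the list,
    # used when no features remain but the leaf is still mixed.
    return Counter(classList).most_common(1)[0][0]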

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        return classList[0]#stop splitting when all of the classes are equal
    if len(dataSet[0]) == 1: #stop splitting when there are no more features in dataSet
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

Here the branching of the tree is represented by nested dictionaries; a usage sketch follows.
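Running createTree on the hypothetical five-record sample set from the earlier sketches (an assumption, since the article does not show its createDataSet) would produce a nested dictionary like this:

myDat = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
labels = ['no surfacing', 'flippers']

myTree = createTree(myDat, labels)
print(myTree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}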

3. Draw a decision tree

If you want to draw a nice-looking tree, you need to compute its depth and width, and doing that by hand with matplotlib functions is tedious. Here you can simply call `treePlotter.createPlot(lensestree)` to draw a clean decision tree; the full treePlotter code is listed below.

'''
Created on Oct 14, 2010

@author: Peter Harrington
'''
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # if the node is a dictionary it is an internal node, otherwise a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else: numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree)[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # if the node is a dictionary it is an internal node, otherwise a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else: thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree)[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':   # if the node is a dictionary it is an internal node, otherwise a leaf
            plotTree(secondDict[key], cntrPt, str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo purposes
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

#def createPlot():
#    fig = plt.figure(1, facecolor='white')
#    fig.clf()
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo purposes
#    plotNode('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()

def retrieveTree(i):
    # Two pre-built trees for testing the plotting functions.
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]

# createPlot(thisTree)
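A quick way to exercise the plotting code without building a tree first is to use one of the canned trees from retrieveTree (this usage sketch is mine, not from the original article):

myTree = retrieveTree(0)
print(getNumLeafs(myTree))    # 3
print(getTreeDepth(myTree))   # 2
createPlot(myTree)            # opens a matplotlib window showing the drawn tree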

Test code

if __name__ == '__main__':
    # Assumes createDataSet (sketched below) is defined in this file and that
    # the plotting code above has been saved separately as treePlotter.py.
    import treePlotter
    mydat, labels = createDataSet()
    thistree = createTree(mydat, labels)
    treePlotter.createPlot(thistree)
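The test code relies on a createDataSet function that the article never shows; a minimal sketch consistent with the 'no surfacing'/'flippers' labels used in retrieveTree (and with the sample data assumed in the sketches above) might be:

def createDataSet():
    # Toy data set: two binary features, class label in the last column.
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels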

Test output: (figure: the decision tree drawn by createPlot)

4. A practical example

Use the decision tree to predict the type of contact lens.
The data set and source code are available; leave a comment if you need them, or email me at [email protected]
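The article does not show the code for this example. A sketch of how it would typically be run, assuming a tab-separated lenses.txt whose columns are age, prescript, astigmatic and tearRate (the classic contact-lens data set layout), might look like:

import treePlotter   # assumes the plotting code above is saved as treePlotter.py

with open('lenses.txt') as fr:
    lenses = [line.strip().split('\t') for line in fr]
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

lensestree = createTree(lenses, lensesLabels)
treePlotter.createPlot(lensestree)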

Experimental result: (figure: the decision tree built from the lens data set)

5. Experiment summary

The core of a decision tree is the decision process. We use Shannon entropy from information theory to measure how mixed the data is after a split, and that gives us the basis for each decision. ID3 splits on a single feature at a time until the branches reach leaf nodes. C4.5 and CART are the more popular algorithms nowadays; I will fill in that hole later, haha.
The overfitting problem of decision trees can be handled by pruning the tree, which improves the algorithm's ability to generalize; I will fill in that hole later too, haha.

Source: blog.csdn.net/ca___0/article/details/109606671