[Artificial Intelligence and Machine Learning] Decision tree ID3 and its Python implementation

1 Decision tree algorithm

The decision tree is a common machine learning method and one of the most widely used classification techniques. It is a form of supervised learning. Common decision tree algorithms include ID3, C4.5, C5.0 and CART (Classification And Regression Tree); CART generally achieves better classification performance than the other decision tree algorithms.

Decision trees make decisions based on a tree structure. Generally, a decision tree contains a root node, several internal nodes and several leaf nodes.

Each internal node represents a judgment on an attribute.
Each branch represents the output of a judgment result.
Each leaf node represents a classification result.
The root node contains the complete set of samples.
The purpose of decision tree learning is to produce a decision tree with strong generalization ability, that is, a strong ability to handle unseen examples. The basic process follows a simple and intuitive "divide-and-conquer" strategy.

This article mainly introduces the ID3 algorithm. The core of the ID3 algorithm is to select features for partitioning based on information gain, and then recursively build a decision tree.

1.1 Feature selection

Feature selection means choosing the optimal splitting attribute: at each node, one feature of the current data is selected as the splitting criterion. As splitting proceeds, we want the samples contained in each branch node to belong to the same class as far as possible, i.e. the "purity" of the nodes should keep increasing.

1.2 Entropy

Entropy measures the degree of uncertainty of an event, i.e. its information content (informally, a large amount of information means there are many uncertain factors behind it). The entropy formula is:

$$H(X) = -\sum_{i=1}^{n} p(x_i)\log_2 p(x_i)$$

where p(x_i) is the probability that category x_i occurs and n is the number of categories. Note that the value of the entropy depends only on the probability distribution of the variable.

The conditional entropy of Y given X is the amount of information (uncertainty) remaining in Y once X is known. It is computed as:

$$H(Y|X) = \sum_{x} p(x)\,H(Y|X=x)$$

For example, when there are only two categories A and B with p(A) = p(B) = 0.5, the entropy is

$$H = -0.5\log_2 0.5 - 0.5\log_2 0.5 = 1$$

When only class A (or only class B) is present,

$$H = -1\cdot\log_2 1 = 0$$
Therefore, for two classes the entropy reaches its maximum of 1 in the least pure (worst-classified) state and its minimum of 0 when the classification is complete. (With n classes the maximum is log2 n.) Since zero entropy is an ideal state, in practice the entropy usually lies strictly between 0 and 1.

Continually minimizing entropy is, in effect, the process of improving the accuracy of the classification.
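A quick numerical check of this (a minimal sketch; H below is just the two-class entropy formula from above):

import math

def H(p):
    # entropy of a two-class node where one class has probability p
    return 0.0 if p in (0.0, 1.0) else -p*math.log2(p) - (1-p)*math.log2(1-p)

for p in (0.5, 0.7, 0.9, 0.99, 1.0):
    print(p, round(H(p), 4))   # 1.0, 0.8813, 0.469, 0.0808, 0.0

The purer the node (p closer to 0 or 1), the lower the entropy.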

1.3 Information gain

Information gain is the change in entropy before and after splitting the data set. Compute the information gain obtained by splitting the data set on each feature; the feature with the highest information gain is the best choice.

Define the information gain of attribute A on data set D as infoGain(D|A): the entropy of D itself minus the conditional entropy of D given A, that is:

$$\mathrm{infoGain}(D|A) = H(D) - H(D|A)$$

The meaning of information gain: how much introducing attribute A reduces the uncertainty of the original data set D.

Calculate the information gain from introducing each attribute, and select the attribute that brings the largest information gain to D; that attribute is the optimal splitting attribute. In general, the greater the information gain, the greater the "purity improvement" obtained by splitting on A.
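As a worked instance (previewing the numbers computed in Section 2 below): the watermelon data set D contains 8 good and 9 bad melons, so H(D) ≈ 0.998. The attribute 色泽 (color) splits D into subsets of sizes 6, 6 and 5 with entropies 0.918, 1.000 and 0.722, giving

$$H(D|色泽) = \frac{6}{17}(0.918) + \frac{6}{17}(1.000) + \frac{5}{17}(0.722) \approx 0.889$$

so infoGain(D|色泽) ≈ 0.998 − 0.889 ≈ 0.109.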

2 Python implementation of the ID3 algorithm

Take the watermelon data set as an example. The watermalon.csv file contains the classic watermelon data set 2.0 (17 samples, 8 good melons and 9 bad ones):

编号,色泽,根蒂,敲声,纹理,脐部,触感,好瓜
1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,是
2,乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是
3,乌黑,蜷缩,浊响,清晰,凹陷,硬滑,是
4,青绿,蜷缩,沉闷,清晰,凹陷,硬滑,是
5,浅白,蜷缩,浊响,清晰,凹陷,硬滑,是
6,青绿,稍蜷,浊响,清晰,稍凹,软粘,是
7,乌黑,稍蜷,浊响,稍糊,稍凹,软粘,是
8,乌黑,稍蜷,浊响,清晰,稍凹,硬滑,是
9,乌黑,稍蜷,沉闷,稍糊,稍凹,硬滑,否
10,青绿,硬挺,清脆,清晰,平坦,软粘,否
11,浅白,硬挺,清脆,模糊,平坦,硬滑,否
12,浅白,蜷缩,浊响,模糊,平坦,软粘,否
13,青绿,稍蜷,浊响,稍糊,凹陷,硬滑,否
14,浅白,稍蜷,沉闷,稍糊,凹陷,硬滑,否
15,乌黑,稍蜷,浊响,清晰,稍凹,软粘,否
16,浅白,蜷缩,浊响,模糊,平坦,硬滑,否
17,青绿,蜷缩,沉闷,稍糊,稍凹,硬滑,否

Read file data

import numpy as np
import pandas as pd
import math
data = pd.read_csv('work/watermalon.csv')
data
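The later call info(8, 17) relies on the class distribution of this file: 8 good melons (好瓜 = 是) among 17 samples. A quick check:

data.好瓜.value_counts()
# 否    9
# 是    8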

Calculate entropy

def info(x, y):
    if x != y and x != 0:
        # mixed node: entropy of the current positive/negative split
        return -(x/y)*math.log2(x/y) - ((y-x)/y)*math.log2((y-x)/y)
    if x == y or x == 0:
        # pure node: entropy is 0
        return 0

info_D = info(8, 17)   # H(D): 8 good melons out of 17 samples
info_D

The result is:
0.9975025463691153

Calculate information gain

# entropy of each 色泽 (color) value
seze_black_entropy = -(4/6)*math.log2(4/6)-(2/6)*math.log2(2/6)
seze_green_entropy = -(3/6)*math.log2(3/6)*2
seze_white_entropy = -(1/5)*math.log2(1/5)-(4/5)*math.log2(4/5)

# conditional entropy of the 色泽 feature
seze_entropy = (6/17)*seze_black_entropy+(6/17)*seze_green_entropy+(5/17)*seze_white_entropy
print(seze_entropy)
# information gain of 色泽
info_D - seze_entropy

The result is:
0.10812516526536531
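The same per-value entropies can be obtained from the info() helper defined above (乌黑: 4 good of 6; 青绿: 3 good of 6; 浅白: 1 good of 5), which makes a handy consistency check:

seze_entropy = (6/17)*info(4,6) + (6/17)*info(3,6) + (5/17)*info(1,5)
print(info_D - seze_entropy)   # ≈ 0.1081, matching the manual computation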

Check the distribution of good and bad melons in each type of root

data.根蒂.value_counts()
# distribution of good and bad melons for each 根蒂 (root) value
print(data[data.根蒂=='蜷缩'])
print(data[data.根蒂=='稍蜷'])
print(data[data.根蒂=='硬挺'])
gendi_entropy = (8/17)*info(5,8)+(7/17)*info(3,7)+(2/17)*info(0,2)
gain_col = info_D - gendi_entropy
gain_col

The resulting information gain for 根蒂 is: 0.142674959566793

Check the distribution of good and bad melons for each knocking sound

data.敲声.value_counts()
# distribution of good and bad melons for each 敲声 (knock sound) value
print(data[data.敲声=='浊响'])
print(data[data.敲声=='沉闷'])
print(data[data.敲声=='清脆'])
qiaosheng_entropy = (10/17)*info(6,10)+(5/17)*info(2,5)+(2/17)*info(0,2)
info_gain = info_D - qiaosheng_entropy
info_gain

The result is approximately 0.1408.

Check the distribution of good and bad melons in each texture

data.纹理.value_counts()
# distribution of good and bad melons for each 纹理 (texture) value
print(data[data.纹理=="清晰"])
print(data[data.纹理=="稍糊"])
print(data[data.纹理=="模糊"])
wenli_entropy = (9/17)*info(7,9)+(5/17)*info(1,5)+(3/17)*info(0,3)
info_gain = info_D - wenli_entropy
info_gain

The result is approximately 0.3806, the largest gain among the attributes examined, so 纹理 (texture) becomes the root split.

Similarly, check the distribution of other columns, which will not be demonstrated here.
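The per-attribute steps above also generalize into a single loop, and from there into the full recursive ID3 construction. The sketch below is hedged: it assumes the label column is named 好瓜 and the ID column 编号, and that all features are discrete, as in the CSV shown earlier. It produces the nested-dict form {attribute: {value: subtree-or-label}} that the plotting code in the next section consumes.

def entropy_of(labels):
    # entropy of a pandas Series of class labels
    probs = labels.value_counts(normalize=True)
    return -sum(p * math.log2(p) for p in probs)

def info_gain(df, attr, label='好瓜'):
    # information gain of splitting df on column attr
    cond = sum(len(sub)/len(df) * entropy_of(sub[label])
               for _, sub in df.groupby(attr))
    return entropy_of(df[label]) - cond

# information gain of every attribute column
for col in data.columns:
    if col not in ('编号', '好瓜'):
        print(col, info_gain(data, col))

def build_tree(df, label='好瓜'):
    # all samples share one class: return a leaf
    if df[label].nunique() == 1:
        return df[label].iloc[0]
    attrs = [c for c in df.columns if c not in ('编号', label)]
    # no attributes left to split on: return the majority class
    if not attrs:
        return df[label].mode()[0]
    # split on the attribute with the highest information gain
    best = max(attrs, key=lambda a: info_gain(df, a, label))
    return {best: {value: build_tree(subset.drop(columns=[best]), label)
                   for value, subset in df.groupby(best)}}

mytree = build_tree(data)
mytree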

Draw a visual tree

import matplotlib.pyplot as plt
import matplotlib

# make matplotlib able to display Chinese characters
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['font.serif'] = ['SimHei']

# style of branching nodes, i.e. decision nodes
decisionNode = dict(boxstyle="sawtooth", fc="0.8")

# style of leaf nodes
leafNode = dict(boxstyle="round4", fc="0.8")

# arrow style
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """
    Draw a single node.
    :param nodeTxt: text describing the node
    :param centerPt: coordinates of the text, i.e. of the node itself
    :param parentPt: coordinates of the parent node
    :param nodeType: node type, either leaf node or decision node
    :return:
    """
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    """
    Count the leaf nodes.
    :param myTree:
    :return:
    """
    # running total of leaf nodes
    numLeafs = 0

    # the first key, i.e. the root node
    firstStr = list(myTree.keys())[0]

    # the content stored under that key
    secondDict = myTree[firstStr]

    # recursively walk the children
    for key in secondDict.keys():
        # if the value is a dict, it is a subtree: recurse
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        # otherwise this is a leaf node
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    """
    Compute the depth (number of levels) of the tree.
    :param myTree:
    :return:
    """
    # keeps track of the maximum depth
    maxDepth = 0

    # the root node
    firstStr = list(myTree.keys())[0]

    # the content stored under that key
    secondDic = myTree[firstStr]

    # walk all child nodes
    for key in secondDic.keys():
        # if the child is a dict, recurse
        if type(secondDic[key]).__name__ == 'dict':
            # depth of the subtree plus 1
            thisDepth = 1 + getTreeDepth(secondDic[key])

        # otherwise this is a leaf node
        else:
            thisDepth = 1

        # update the maximum depth
        if thisDepth > maxDepth:
            maxDepth = thisDepth

    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    """
    Compute the midpoint between parent and child nodes and place label text there.
    :param cntrPt: child node coordinates
    :param parentPt: parent node coordinates
    :param txtString: label text
    :return:
    """
    # midpoint on the x axis
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    # midpoint on the y axis
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    # draw the text
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    """
    Draw all nodes of the tree, recursively.
    :param myTree: the tree
    :param parentPt: coordinates of the parent node
    :param nodeTxt: text of the node
    :return:
    """
    # number of leaf nodes
    numLeafs = getNumLeafs(myTree=myTree)

    # depth of the tree
    depth = getTreeDepth(myTree=myTree)

    # label of the current root node
    firstStr = list(myTree.keys())[0]

    # x coordinate of the current root: the current x offset plus half the width of
    # this subtree (e.g. on the first call the initial x offset -1/2W plus the root
    # centre (1+W)/2W gives 1/2); the current y offset is used as the y coordinate
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)

    # draw the label on the edge to the parent node
    plotMidText(cntrPt, parentPt, nodeTxt)

    # draw the node itself
    plotNode(firstStr, cntrPt, parentPt, decisionNode)

    # the subtrees under the current root
    secondDict = myTree[firstStr]

    # new y offset: move down by 1/D, i.e. the y coordinate of the next level
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD

    # iterate over all keys
    for key in secondDict.keys():
        # if the value is a dict there is a subtree below: recurse
        if isinstance(secondDict[key], dict):
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # new x offset: the next leaf is drawn 1/W further to the right
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            # uncomment to watch the leaf coordinates change
            # print((plotTree.xOff, plotTree.yOff), secondDict[key])
            # draw the leaf node
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # draw the label on the edge between leaf and parent
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))

    # before returning from the recursion, move the y offset back up by 1/D,
    # i.e. back to drawing the level above
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


def createPlot(inTree):
    """
    Plot the given decision tree.
    :param inTree: decision tree as a dict
    :return:
    """
    # create a figure
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # total width of the decision tree
    plotTree.totalW = float(getNumLeafs(inTree))
    # total depth of the decision tree
    plotTree.totalD = float(getTreeDepth(inTree))
    # initial x offset, -1/2W; each step moves right by 1/W, so the first leaf is
    # drawn at x = 1/2W, the second at 3/2W, the third at 5/2W, the last at (2W-1)/2W
    plotTree.xOff = -0.5/plotTree.totalW
    # initial y offset; each step moves down or up by 1/D
    plotTree.yOff = 1.0
    # draw the node images
    plotTree(inTree, (0.5, 1.0), '')
    # show the plot
    plt.show()


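# NOTE: `mytree` must already hold the decision tree as a nested dict of the
# form {attribute: {value: subtree-or-leaf-label}}, for example the tree
# returned by the hedged build_tree() sketch in the previous section. An
# illustrative fragment of the watermelon tree looks like:
# mytree = {'纹理': {'模糊': '否',
#                    '稍糊': {'触感': {'软粘': '是', '硬滑': '否'}},
#                    '清晰': {...}}}   # '清晰' splits further on other attributes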
if __name__ == '__main__':
    createPlot(mytree)

Summary

ID3 is a classic machine learning algorithm for classification problems. It makes decisions by building a tree structure over the feature space, using information gain as the splitting criterion. The key to the ID3 algorithm is to select, at each node, the attribute whose split maximizes the information gain. With the Python implementation above, we can build a decision tree model for classification prediction and decision analysis.


References
https://zhuanlan.zhihu.com/p/133846252
https://cuijiahua.com/blog/2017/11/ml_2_decision_tree_1.html
https://blog.csdn.net/tauvan/article/details/121028351
