["Machine Learning" Zhou Zhihua (Watermelon Book)] Python Implementation of Construction, Visualization and Prediction of ID3 Decision Tree Based on Information Gain

Introduction

GitHub repository: https://github.com/One1h/DecisionTree (stars are appreciated)

Following Zhou Zhihua's book "Machine Learning", this article implements from scratch the construction of an ID3 decision tree based on information gain. Comments and suggestions are very welcome; let's learn and improve together!

1. Building the decision tree - DecisionTree.py

1. Idea

Basic algorithm

Figure: the basic decision tree learning algorithm (pseudocode from the book), with information gain as the attribute-selection criterion for ID3.
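In outline, the recursive procedure paraphrased from the book (and mirrored by TreeGenerate in the code below) is:

  • if all samples in the current set D belong to one class, return a leaf of that class;
  • if the attribute set is empty, or all samples take the same values on every remaining attribute, return a leaf labeled with the majority class of D;
  • otherwise pick the attribute with the largest information gain, split D by its values, and recurse on each subset; an empty subset becomes a leaf labeled with the majority class of D.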

Information entropy

"Information entropy" is the most commonly used measure of the purity of a sample set. Suppose the proportion of samples of the k-th class in the current sample set D is p_k (k = 1, 2, ..., |Y|); the information entropy of D is then defined as

Ent(D) = -\sum_{k=1}^{|\mathcal{Y}|} p_k \log_2 p_k

The smaller Ent(D) is, the higher the purity of D.
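As a quick worked example, the sketch below computes Ent(D) for the full watermelon dataset used later in this article (8 good melons and 9 bad melons out of 17 samples); it should print roughly 0.998. The helper name ent is only for illustration and is not part of the repository code.

import math

def ent(counts):
    """Information entropy of a sample set, given its per-class counts."""
    total = sum(counts)
    return -sum(c / total * math.log2(c / total) for c in counts if c > 0)

print(ent([8, 9]))  # entropy of the 17-sample watermelon dataset, ≈ 0.998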

Information gain

Suppose the discrete attribute a has V possible values {a^1, a^2, ..., a^V}. Splitting the sample set D on a produces V branch nodes, where the v-th branch node contains the samples of D whose value on a is a^v; denote this subset D^v. The information entropy of each D^v can then be computed. Since different branches contain different numbers of samples, each branch is given the weight |D^v|/|D|, so branches with more samples have more influence. The "information gain" obtained by splitting D on attribute a is therefore

Gain(D, a) = Ent(D) - \sum_{v=1}^{V} \frac{|D^v|}{|D|} Ent(D^v)

In general, the larger the information gain, the greater the "purity improvement" obtained by splitting on a. We can therefore use information gain to choose the splitting attribute of the decision tree.
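For example, the sketch below computes Gain(D, 色泽) on the watermelon dataset: splitting by 色泽 gives the subsets 青绿 (3 good, 3 bad), 乌黑 (4 good, 2 bad) and 浅白 (1 good, 4 bad), which yields a gain of about 0.109. The helpers ent and gain are illustrative only.

import math

def ent(counts):
    total = sum(counts)
    return -sum(c / total * math.log2(c / total) for c in counts if c > 0)

def gain(subsets):
    """Information gain of a split, where each subset is a (good, bad) count pair."""
    total = sum(sum(s) for s in subsets)
    parent = [sum(col) for col in zip(*subsets)]
    return ent(parent) - sum(sum(s) / total * ent(s) for s in subsets)

# 色泽: 青绿 -> (3, 3), 乌黑 -> (4, 2), 浅白 -> (1, 4)
print(gain([(3, 3), (4, 2), (1, 4)]))  # ≈ 0.109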

Decision tree storage structure

The decision tree is stored as a multi-way tree: each node keeps its children in a list, so a node can have any number of children (one per value of its splitting attribute).

Figure: the multi-way decision tree structure.
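A minimal sketch of the idea (the field names mirror the BTreeNode class defined in the code below): each node stores its splitting attribute, or a class label if it is a leaf, plus the attribute value on the edge from its parent and a variable-length list of children.

class Node:
    def __init__(self, keyword, parent_value=None):
        self.keyword = keyword      # splitting attribute, or a class label at a leaf
        self.parent = parent_value  # attribute value on the edge from the parent
        self.child_nodes = []       # one child per value of this node's splitting attribute

# a two-node example: split on 纹理, with a leaf for the value 清晰
root = Node('纹理')
root.child_nodes.append(Node('好瓜', parent_value='清晰'))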

2. Code

import math
from copy import copy
from typing import List
import PlotTree as pt


# Build the dataset (the watermelon dataset 2.0 from the book)
def createDataSet():
    dataSet = [
        # 1
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 2
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        # 3
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 4
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        # 5
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        # 6
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        # 7
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        # 8
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
        # 9
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        # 10
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        # 11
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        # 12
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        # 13
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        # 14
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        # 15
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        # 16
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        # 17
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜']
    ]

    # List of feature names
    labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']

    # All possible values of each feature
    labels_full = {}

    for i in range(len(labels)):
        labelList = [example[i] for example in dataSet]
        uniqueLabel = set(labelList)
        labels_full[labels[i]] = uniqueLabel

    return dataSet, labels, labels_full


# Multi-way tree node
class BTreeNode(object):
    def __init__(self, parent=None, keyword=None, child_nodes=None):
        '''parent: the attribute value of the parent split that leads to this node, e.g. "浅白"
        keyword: the splitting attribute of this node, or a class label for a leaf, e.g. "色泽"
        child_nodes: list of child nodes, one per value of this node's splitting attribute'''
        self.parent = parent
        self.keyword = keyword
        # avoid a shared mutable default argument: each node gets its own list
        self.child_nodes = child_nodes if child_nodes is not None else []

    def getkeyword(self):
        return self.keyword

    def addchild(self, node):
        self.child_nodes.append(node)

    def setkeyword(self, keyword):
        self.keyword = keyword

    def setparent(self, parent):
        self.parent = parent

    def shownode(self):
        print("parent:{}\nkeyword:{}\nchild_nodes: ".format(self.parent, self.keyword))
        for node in self.child_nodes:
            print(node.parent, node.keyword)
        print()


# One term of the information entropy sum: -p_k * log2(p_k)
def Entropy(pk: float) -> float:
    if pk == 0.0: return 0.0
    return -1 * pk * math.log(pk, 2)


# Information gain helper (not called below; BestAttribute computes the gain inline)
def Gain(D: List[int], Ent: float) -> float:
    G = Ent
    for Dv in D:
        G -= abs(Dv / sum(D)) * Entropy(Dv / sum(D))
    return G


# Find the best splitting attribute (the one with the largest information gain)
def BestAttribute(dataSet_, labels_, labels_full_):
    # entropy of the current sample set
    temp = []
    D_t = [0, 0]
    for i, data in enumerate(dataSet_):
        temp.append(i + 1)
        if data[-1] == '好瓜':
            D_t[0] += 1
        if data[-1] == '坏瓜':
            D_t[1] += 1
    Ent = Entropy(D_t[0] / len(temp)) + Entropy(D_t[1] / len(temp))

    # information gain of every candidate attribute
    Gains = []
    for ind, label in enumerate(labels_):
        l = len(labels_full_[label])
        G = Ent
        label_t = list(labels_full_[label])
        D = []
        Ents = []
        for i in range(l):
            D.append([])
            Ents.append(0)

        # split the sample indices into subsets Dv by attribute value
        for i, data in enumerate(dataSet_):
            attribute_ind = label_t.index(data[ind])
            D[attribute_ind].append(i + 1)

        # count good / bad melons in each Dv
        Dv = []
        for i in D:
            temp = [0, 0]
            for j in i:
                if dataSet_[j - 1][-1] == '好瓜':
                    temp[0] += 1
                if dataSet_[j - 1][-1] == '坏瓜':
                    temp[1] += 1
            Dv.append(temp)

        # entropy of each subset
        for i, data in enumerate(Dv):
            good, bad = data
            total = good + bad
            if total != 0:
                Ents[i] = Entropy(good / total) + Entropy(bad / total)

        # information gain of this attribute
        for i, data in enumerate(Ents):
            G -= (Dv[i][0] + Dv[i][1]) / len(dataSet_) * data
        Gains.append(G)

    # pick the attribute with the largest information gain
    label_num = 0
    for i, g in enumerate(Gains):
        if g > Gains[label_num]:
            label_num = i

    return labels_[label_num], Gains[label_num]


# If all samples belong to the same class, return (True, that class label)
def SameClass(dataset_):
    label = ''
    same_class = True
    for i, data in enumerate(dataset_):
        if i == 0:
            continue
        if data[-1] != dataset_[i - 1][-1]:
            same_class = False
            break
    if same_class:
        label = dataset_[0][-1]

    return same_class, label


# The attribute set is empty, or all samples take the same value on every remaining attribute
def NoneOrSameattr(dataset_, labels_):
    if labels_ != []:
        # compare every pair of samples on their attribute columns (the class label is excluded)
        for i in range(len(dataset_) - 1):
            for j in range(i + 1, len(dataset_)):
                if dataset_[i][:-1] != dataset_[j][:-1]:
                    return False

    return True


# Return the majority class of the sample set
def MostClass(dataset_):
    good, bad = 0, 0
    for data in dataset_:
        if data[-1] == '好瓜':
            good += 1
        if data[-1] == '坏瓜':
            bad += 1
    label = '好瓜' if good >= bad else '坏瓜'

    return label


# Generate a branch node for each subset produced by splitting on best_attr
def GetSubNode(dataset_, labels_, labels_full_, best_attr):
    subnodes = []
    ind = labels_.index(best_attr)
    # for every possible value of the splitting attribute, build a subtree from the matching samples
    for attr in labels_full_[best_attr]:
        subtree = BTreeNode()
        subdataset = []
        for i, data in enumerate(dataSet_ if False else dataset_):
            if data[ind] == attr:
                temp = copy(data)
                temp.pop(ind)
                subdataset.append(temp)

        # empty subset for this value: make a leaf labeled with the majority class of dataset_
        if not subdataset:
            label = MostClass(dataset_)
            subtree.setkeyword(label)

        # non-empty subset: recurse to build the subtree
        else:
            sublabels_full = copy(labels_full_)
            if best_attr in sublabels_full:
                sublabels_full.pop(best_attr)

            sublabels = copy(labels_)
            if best_attr in sublabels:
                sublabels.remove(best_attr)

            subtree = TreeGenerate(subdataset, sublabels, sublabels_full)

        subtree.setparent(attr)
        subnodes.append(subtree)

    return subnodes


# Generate the decision tree recursively
def TreeGenerate(dataset_, labels_, labels_full_):
    root = BTreeNode()
    # case 1: all samples belong to the same class -> leaf of that class
    flag, label = SameClass(dataset_)
    if flag:
        root.setkeyword(label)
        return root

    # case 2: attribute set empty, or all samples identical on every attribute -> majority-class leaf
    if NoneOrSameattr(dataset_, labels_):
        label = MostClass(dataset_)
        root.setkeyword(label)
        return root

    # case 3: choose the attribute with the largest information gain
    best_attr, gain = BestAttribute(dataset_, labels_, labels_full_)
    root.setkeyword(best_attr)

    # split on that attribute and recurse on each subset
    root.child_nodes = GetSubNode(dataset_, labels_, labels_full_, best_attr)
    return root


# Prediction with the decision tree
def test(data, dataset, label, labels_full, tree):
    res = ''
    # walk down the tree until a class label is reached
    while res not in ['坏瓜', '好瓜']:
        # splitting attribute of the current node
        attr_divide = tree.keyword
        ind = label.index(attr_divide)

        for node in tree.child_nodes:
            # follow the branch whose attribute value matches the sample
            if node.parent == data[ind]:
                tree = node
                res = node.keyword
                break

    return res


if __name__ == '__main__':
    dataSet, labels, labels_full = createDataSet()
    tree = TreeGenerate(dataSet, labels, labels_full)
    pt.createPlot(tree)
    data = ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜']
    print(test(data, dataSet, labels, labels_full, tree))
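Running the script builds the tree from the full watermelon dataset, saves the plot to tree.jpg, and prints the prediction for the first training sample. As a small sanity check you might add at the end of the __main__ block (assuming the run matches the book's result, where 纹理 has the largest information gain, about 0.381, and the tree fits the training data):

    print(tree.keyword)                                     # expected: 纹理
    print(test(data, dataSet, labels, labels_full, tree))   # expected: 好瓜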

2. Decision tree visualization - PlotTree.py

1. Idea

Draw the tree diagram with matplotlib, using annotate to place each node and the arrow connecting it to its parent; node positions are computed from the number of leaves and the depth of the tree.
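The building block is matplotlib's annotate, which draws a text box at one position with an arrow coming from another position; the decision-node boxes, leaf circles and edges below are all drawn this way. A minimal stand-alone sketch (positions and labels here are made up for illustration):

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.set_axis_off()
# a decision node at (0.5, 0.8); xy == xytext, so no arrow is drawn
ax.annotate('root', xy=(0.5, 0.8), xytext=(0.5, 0.8), ha='center',
            bbox=dict(boxstyle='round', fc='0.8'))
# a leaf at (0.3, 0.4) with an arrow connecting it to its parent at (0.5, 0.8)
ax.annotate('leaf', xy=(0.5, 0.8), xytext=(0.3, 0.4), ha='center',
            bbox=dict(boxstyle='circle', fc='0.8'),
            arrowprops=dict(arrowstyle='<-'))
plt.savefig('annotate_demo.jpg')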

2. Code

import matplotlib.pyplot as plt

# Font for matplotlib: a CJK-capable font so the Chinese node labels render (pick one installed on your system)
plt.rcParams['font.sans-serif'] = ['Droid Sans Fallback']
# Box style for decision (internal) nodes: boxstyle is the box shape, fc is the face color (gray level)
decisionNode = dict(boxstyle="round", fc="0.8")
# Box style for leaf nodes
leafNode = dict(boxstyle="circle", fc="0.8")
# Arrow style for the edges of the tree
arrow_args = dict(arrowstyle="<-")


# nodeTxt: text to display; centerPt: position of the text box (arrow end); parentPt: point the arrow starts from
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="bottom", ha="center",
                            bbox=nodeType, arrowprops=arrow_args)


# Count the leaf nodes of a tree
def getNumLeafs(myTree):
    # number of leaves found so far
    numLeaf = 0
    # children of the current node
    nodes = myTree.child_nodes
    # walk over the children
    for node in nodes:
        # internal node: recurse and add the leaves of its subtree
        if node.child_nodes:
            numLeaf += getNumLeafs(node)
        # leaf node
        else:
            numLeaf += 1
    # return the number of leaves
    return numLeaf


# Compute the depth of a tree
def getTreeDepth(myTree):
    # maximum depth found so far
    maxDepth = 0
    # children of the current node
    nodes = myTree.child_nodes
    for node in nodes:
        # internal node: its depth is 1 plus the depth of its subtree
        if node.child_nodes:
            thisDepth = 1 + getTreeDepth(node)
        # leaf node: depth 1
        else:
            thisDepth = 1
        # keep the maximum over all children
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    # return the depth of the tree
    return maxDepth


# Draw the text on the edge between a node and its parent (the attribute value)
def plotMidText(cntrPt, parentPt, txtString):
    # x coordinate of the label
    xMid = (parentPt[0] - cntrPt[0]) / 2.5 + cntrPt[0]
    # y coordinate of the label
    yMid = (parentPt[1] - cntrPt[1]) / 2.5 + cntrPt[1]
    # draw the text
    createPlot.ax1.text(xMid, yMid, txtString)


# Draw the decision tree recursively
def plotTree(myTree, parentPt, nodeTxt):
    # number of leaves in the current subtree
    numLeafs = getNumLeafs(myTree)
    # splitting attribute of the current node
    firstStr = myTree.keyword
    # node position: x is offset by the subtree's leaf count relative to the whole tree, y is the current level
    cntrPt = (plotTree.xOff + (1.0 + numLeafs) / len(myTree.child_nodes) / plotTree.totalW, plotTree.yOff)
    # for the root call (parentPt is (0, 0)), anchor the arrow at the node itself
    if parentPt == (0, 0):
        parentPt = cntrPt
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    # children of the current node
    nodes = myTree.child_nodes
    # going one level down, so decrease y (the layout starts from the top of the figure)
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    # walk over the child nodes
    for node in nodes:
        # internal node: draw its subtree recursively
        if node.child_nodes:
            plotTree(node, cntrPt, node.parent)
        # leaf node
        else:
            # x coordinate of the leaf
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            # draw the leaf
            plotNode(node.keyword, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            # attribute value on the edge
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, node.parent)
    # back up one level: restore y
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# Main plotting function
def createPlot(inTree):
    # create a figure with a white background
    fig = plt.figure(1, facecolor='white')
    # clear the figure
    fig.clf()
    # axes with no ticks
    axprops = dict(xticks=[], yticks=[])
    # a single subplot with no frame and no axes
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # plotTree.totalW holds the width of the tree (its number of leaves)
    plotTree.totalW = float(getNumLeafs(inTree))
    # plotTree.totalD holds the depth of the tree
    plotTree.totalD = float(getTreeDepth(inTree))
    # starting x coordinate
    plotTree.xOff = -0.5 / plotTree.totalW
    # starting y coordinate
    plotTree.yOff = 1.0
    # draw the tree
    plotTree(inTree, (0, 0), '')
    # save the figure to disk
    plt.savefig('tree.jpg')

3. Visualization result

Figure: the resulting decision tree (also saved to tree.jpg).

Summary

Decision trees involve more than what is covered here; on top of this implementation you can also add:

  • other attribute-selection criteria, such as the gain ratio or the Gini index (a small sketch follows this list);
  • pre-pruning and post-pruning;
  • handling of continuous-valued attributes;
  • handling of missing values.
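As an illustration of the first point, here is a minimal sketch of the gain ratio and the Gini index, written in the same style as the entropy helpers earlier in this article (counts is a per-class count list, subsets is a list of (good, bad) count pairs; these helper names are not part of the repository code):

import math

def ent(counts):
    total = sum(counts)
    return -sum(c / total * math.log2(c / total) for c in counts if c > 0)

def gain_ratio(subsets):
    """Gain ratio = information gain / intrinsic value (IV) of the split."""
    total = sum(sum(s) for s in subsets)
    parent = [sum(col) for col in zip(*subsets)]
    gain = ent(parent) - sum(sum(s) / total * ent(s) for s in subsets)
    iv = -sum(sum(s) / total * math.log2(sum(s) / total) for s in subsets if sum(s) > 0)
    return gain / iv if iv > 0 else 0.0

def gini(counts):
    """Gini index of a sample set: 1 minus the sum of squared class proportions."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

# 色泽 split of the watermelon dataset: (good, bad) counts per attribute value
print(gain_ratio([(3, 3), (4, 2), (1, 4)]))  # gain ratio of splitting on 色泽
print(gini([8, 9]))                          # Gini index of the full dataset, ≈ 0.498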

Original post: blog.csdn.net/lucifer479/article/details/115867687