机器学习实战(二)决策树DT(Decision Tree、ID3算法)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zhq9695/article/details/83091285

目录

0. 前言

1. 信息增益(ID3)

2. 决策树(Decision Tree)

3. 实战案例

3.1. 隐形眼镜案例

3.2. 存储决策树

3.3. 决策树画图表示


学习完机器学习实战的决策树,简单的做个笔记。文中部分描述属于个人消化后的理解,仅供参考。

所有代码和数据可以访问 我的 github

如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~

0. 前言

决策树(Decision Tree)的执行流程很好理解,如下图所示(图源:西瓜书),在树上的每一个结点进行判断,选择分支,直到走到叶子结点,得出分类:

  • 优点:计算复杂度不高、输出结果易于理解、对缺失值不敏感
  • 缺点:可能会产生过拟合
  • 适用数据类型:数值型和标称型(数值型数据需要离散化)

决策树构建中,目标就是找到当前哪个特征在划分数据时起到决定性作用,划分数据有多种办法,如信息增益(ID3)、信息增益率(C4.5)、基尼系数(CART),本篇主要介绍信息增益(ID3算法)。

1. 信息增益(ID3)

首先,介绍香农熵(entropy),熵定义为信息的期望值,熵越高,说明信息的混乱程度越高

Ent(D)=-\sum_{k=1}^{\left|\gamma \right|}p(k)\log_{2}p(k)

其中,D 表示数据集,k 表示数据集中的每一个类别,p(k) 表示这个属于类别的数据占所有数据的比例。

信息增益(information gain)定义为原始的熵减去当前的熵,增益越大,说明当前熵越小,说明数据混乱程度越小

Gain(D,a)=Ent(D)-\sum_{v=1}^{V}\frac{\left|D^v\right|}{\left|D\right|}Ent(D^v)

其中,V 表示按照此特征划分的子集数量,v 表示第 v 个子集,Ent(D^v) 表示子集的信息熵,\frac{\left|D^v\right|}{\left|D\right|} 表示子集数据占所有数据的比例。

注:信息增益更偏向于选择取值较多的特征,这是它的缺点。

2. 决策树(Decision Tree)

算法流程可简单表示为:

  1. 遍历当前数据所有的特征,计算信息增益最大的特征,作为当前划分数据的结点,并去除此特征
  2. 对划分后每个分支上的子集继续进行步骤 1 
  3. 如果当前子集内的数据都是同一类型,则停止划分,标记叶子结点
  4. 如果子集内数据还未统一类型,而已经没有特征,则采用多数表决原则

3. 实战案例

以下将展示书中的三个案例的代码段,所有代码和数据可以在github中下载:

3.1. 隐形眼镜案例

# coding:utf-8
from math import log
import operator
import pickle

"""
隐形眼镜案例
"""


# 计算香农熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt


# 按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    # 只选择第 axis 列的值为 value 的数据
    # 去除这个特征,取数据[:axis] 和 [axis+1:] 段
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    # 遍历每一个特征
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        # 遍历这个特征的所有特征值
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            # 判断这个子集占所有数据集的比例
            prob = len(subDataSet) / float(len(dataSet))
            # 新的信息熵 = 所有子集的信息熵乘以比例再求和
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature


# 多数表决原则
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# 创建决策树
# labels 为特征的标签
def createTree(dataSet, labels):
    # 获取当前数据集最后一列的类别信息
    classList = [example[-1] for example in dataSet]
    # 如果最后一列都是一种类别
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 如果当前数据集没有可划分的特征
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # 获取最好的划分数据集特征
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # 在特征标签中删除当前特征
    del (labels[bestFeat])
    # 获取这个特征的列,遍历此特征的所有特征值
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLabels = labels[:]
        # 特征有几个取值,这个结点就有几个分支
        # 每个取值,都划分出子集,递归建树
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree


# 分类函数
def classify(inputTree, featLabels, testVec):
    # 获取第一个特征
    firstStr = list(inputTree.keys())[0]
    # 获取这个特征下的键值对的值
    secondDict = inputTree[firstStr]
    # 获取这个特征的索引
    featIndex = featLabels.index(firstStr)
    # 遍历每一个分支
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            # 判断当前分支下是否还有分支
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel


if __name__ == '__main__':
    fr = open('lenses.txt')
    lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lensesTree = createTree(lenses, lensesLabels)
    print(lensesTree)

3.2. 存储决策树

# 存储树
def storeTree(inputTree, filename):
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


# 取出存储的树
def grabTree(filename):
    fr = open(filename, 'rb')
    return pickle.load(fr)

3.3. 决策树画图表示

# coding:utf-8
import matplotlib.pyplot as plt

# 解决显示中文问题
from pylab import *

mpl.rcParams['font.sans-serif'] = ['SimHei']

"""
决策树画图
"""


# 创建树的字典
def retrieveTree(i):
    listOfTrees = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                   {'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {
                       'prescript': {'hyper': {'age': {'pre': 'no lenses', 'young': 'hard', 'presbyopic': 'no lenses'}},
                                     'myope': 'hard'}}, 'no': {'age': {'pre': 'soft', 'young': 'soft', 'presbyopic': {
                       'prescript': {'hyper': 'soft', 'myope': 'no lenses'}}}}}}}}
                   ]
    return listOfTrees[i]


# 获取叶节点的数目
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


# 获取树的层数
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


# 使用文本注解绘制树节点
decisionNode = dict(boxstyle='sawtooth', fc='0.8')
leafNode = dict(boxstyle='round4', fc='0.8')
arrow_args = dict(arrowstyle='<-')


# 画节点
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType,
                            arrowprops=arrow_args)


# 在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


# 画树
def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff),
                     cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# 主要画图函数
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()


if __name__ == '__main__':
    myTree = retrieveTree(1)
    createPlot(myTree)

如果这篇文章对你有一点小小的帮助,请给个关注喔~我会非常开心的~

猜你喜欢

转载自blog.csdn.net/zhq9695/article/details/83091285