机器学习之决策树原理及实现(python)

版权声明:博客中难免不少纰漏甚至严重错误,希望大家指正,这是对我最大的帮助。同时本博客最大的目的也在于交流学习,而不在关注和传播。任重而道远,MrYx与您共勉。 https://blog.csdn.net/yexiaohhjk/article/details/83218514

前言

这篇博客只简略概述决策树原理,更多是逐步改善的代码并用于记载个人学习的过程。详细学习决策树方面的内容推荐看论文,周志华老师的西瓜书和机器学习实战,如有疏漏和错误请指正。

实现全部代码放在github,欢迎FollowStart!

决策树基本概念

决策树是一类非常常见用于分类和回归任务的一类机器学习算法,由于它存储结构和执行过程是在一颗树上从树的根结点开始做出符合判定条件的决策不停往下遍历直到树的叶子节点得出分类或者回归的结果,所以该算法很形象的称为决策树(decision tree)

比如:在下图这个问题里,根据是否满足前两列的两个条件判断一个东西是否是鱼?
在这里插入图片描述
根据这个表的数据,我们可以用决策树算法学习生成一颗如下图的决策树:
在这里插入图片描述
如图所示,每个非叶子的根节点都是判定属性名称,叶子节点上是判定的条件。

显然,这颗决策树是得来是根据我们先给出训练集及上表中的数据学习来的,所以最终的结论对应我们所希望的判定结果。而决策树学习的目的是为了产生一颗泛化能力强,即处理未见实例即测试集能力强的决策树。

其生成的策略也非常容易分析得到,同时由于是树型结构,生成过程也满足“分而治之”的思想,不理解也没关系,简言之就是在树里常见先序递归遍历求树高度,宽度等操作,在决策树里都是可行且复杂度分析都是可观的。

决策树生成流程 如下图所示:

在这里插入图片描述

显然,对树形算法的敏感和研读决策树生成流程图(不熟的需要反思数据结构怎么学的),该生成过程是一个递归的过程。而递归算法最重要是分析出递归终止的条件,这里面只有三种情况会返回空,导致递归终止。

如图所示:

  • (1)当前节点包含样本的结果属于同一类别,不需划分。 则直接将该节点标为同一类别的叶子节点
  • (2)当前所有属性划分完了,或者所有样本在所有属性上结果相同,无法划分。则也将该节点标记为叶子节点,该叶子节点类型由当前包含样本结果最多的类型。
  • (3)当前按照属性划分样本集合为空,没有样本可以划分。则同样把当前节点标记为叶子节点,但类型由其父节点出现最多类型决定。

分析完递归终止条件,我们可以发现整个算法最关键和没确定的部分只有第八行,即当前如何在这么多样本里选择用什么属性特征作为当前结点的划分条件从而生成决策树的泛化和判定能力是最优的呢?(明显每个节点选择不同条件生成决策树形状是不一样的,泛化和判定性能当然也不同)

决策树的特征划分标准

在决策树划分标准这一步里,根据不同划分算法的提出,这里就演化出了很多不同版本的决策树!
常见的划分标准有三种: (1) 信息增益 (2)增益率 (3)基尼指数

一般而言,随着划分过程不断进行,我们希望决策树的分支节点所包含的样本尽可能属于同一类别,即结点纯度(purity)越来越高。

信息增益:

首先不得不提信息熵的概念,而我这学期正好上了《信息论》,还是有点基础。其实没有基础也不怕,不用把这个词想的很高大上复杂化,其实信息熵思想很简单概括就是反映一个事情包含平均不确定信息的信息量大小。
具体分类分析就是一个事件如果没有确定发生,信息熵就表示这件事情不确定发生的概率大小。如果确定已经发生了,信息熵就表示这件事所包含信息量的大小。

所以信息熵(information enrtopy)是度量样本集合纯度最常用的一种指标。 假定当前样本集合D中第k类样本所占比例为pk,则样本集合D的信息熵定义为:
在这里插入图片描述

假定通过属性划分样本集D,产生了V个分支节点,v表示其中第v个分支节点,易知:分支节点包含的样本数越多,表示该分支节点的影响力越大。故可以计算出划分后相比原始数据集D获得的“信息增益”(information gain)。

在这里插入图片描述
每一次选择哪个特征作为划分标准的时候采用信息增益值最大作为准则的决策树称为ID3。同时我们下面实现的决策树代码即采用是信息增益准则,即实现是ID3类型的决策树。

更多关于如何使用信息增益作为划分标准的细节推荐看西瓜书这里不啰嗦了。

增益率

实际上,观察信息增益的公式,我们可以发现信息增益准则对取值数目较多的属性有所偏好,为减少这种偏好可能带来的不利影响,著名的C4.5决策树算法不直接采用信息增益而是和增益率结合起来选择划分最优属性。

增益率公式定义为:D表示所有样本,a表示选择的属性。
在这里插入图片描述
称为属性a的固有值.属性a可能取值的数目越多,则IV(a)值通常越大。所以需要注意的是,增益率准则对可取值数目较小的属性有所偏好。因此,C4.5并不是直接选择增益率最大的候选划分属性,而是使用一个启发式:先从划分属性中找出信息增益高于平均水平的属性,再从中选择增益率最高的。(那我好奇,如果我先从增益率高于平均水平里找出信息增益最高的,效果会不会更好?

基尼指数

CART决策树使用基尼指数来选择划分属性,采用和上面相同符号的公式:
在这里插入图片描述

直观来说基尼指数反映的是从样本集D中随机抽取两个样本,其类别标记不一致的概率,因此Gini(D)越小则数据集的纯度越高

进而,使用属性α划分后的基尼指数为:
在这里插入图片描述

于是,我们在候选属性集合里选择哪个使划分后基尼指数最小的属性作为最优划分属性!

决策树的剪枝

剪枝是决策树算法对付“过拟合”的主要手段。主要分预剪枝:从根往下和后剪枝:待树生成后从叶子往上。
思想很简单,结合上面划分标准,每次求当前节点划分是否增益率或者信息增益有所改善,没有就减去,保持最简洁的树形状。
具体等我代码里补充后,在详细一点说明。

决策树的代码实现

代码是参照《机器学习实战》码出来,简洁优雅,修改个别语言版本错误,之后亲测完美运行。

这里决策树,我们并不构造新的数据结构,而是使用python语言内嵌的数据结构dict字典巧妙简单的存储树的节点信息。

首先是决策树构造部分代码:
creataDataSet()函数创建了一个上面判断是否是鱼的例子数据,用作其他函数测试。
采用是信息增益作为划分标准,其他函数的功能即实现都可以看代码及注释弄懂,不懂可以留言讨论。

#!/usr/bin/env python
# encoding: utf-8

'''
@author: MrYx
@contact: [email protected]
@author github: https://github.com/MrYxJ
@file: UseDecisionTree.py
@time: 18-10-20 下午1:50
'''


from math import log
import operator
import PlotDecisionTree

def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,0,'no']
    ]
    labels  = ['no surfaceing','flippers']
    return dataSet, labels


def calcShannonEnt(dataSet):
    """
    计算给定数据集的信息熵
    :param dataSet:
    :return:
    """
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet :
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2) #信息熵定义公式

    return shannonEnt

def splitDataSet(dataSet, axis ,value):
    """
    按照给定的特征划分数据集
    :param dataSet:
    :param axis:
    :param value:
    :return:
    """
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis] # 抽取从0到这个数之前的数
            reducedFeatVec.extend(featVec[axis+1: ]) #抽取从这个数之后一位的所有数t
            retDataSet.append(reducedFeatVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):
    """
    选择最好的数据集划分方式
    :param dataSet:
    :return:
    """
    numFeatures = len(dataSet[0]) -1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1

    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
       # print('featList:', featList)
        uniqueVals = set(featList)
       # print('uniqueVal:' ,uniqueVals)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy -newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

def majorityCnt(classList):
    """
    当数据集已经处理所有属性,或者所有属性值都相同时,选出结果出现最多作为叶子节点结果
    :param classList:
    :return:
    """
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems() , key = operator.itemgetter(1), reverse = True)
    print('classcount:', sortedClassCount)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0])  == len(classList):  #当前所有样本都属于同一类
        return classList[0]

    if len(dataSet[0]) == 1:     # 当前属性集为空,只有一列结果
        return majorityCnt(classList)  # 返回出现结果频率最大作为结果
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLables = labels[:] # 在python 里面 list是引用类型的变量,所以防止改变用一个新的变量代替。
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLables)
    return myTree

def classify(inputTree, featLabels, testVec):
    """
    使用决策树分类,传入一颗以dict字典形式建好的决策树和标签,测试数据,输出决策树分类的结果。
    :param inputTree:
    :param featLabels:
    :param testVec:
    :return:
    """
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else : classLabel = secondDict[key]
    return classLabel


def storeTree(inputTree, filename):
    '''
    使用pickle模块存储决策树
    :param inputTree:
    :param filename:
    :return:
    '''
    import pickle
    fw = open(filename, 'wb')
    pickle.dump(inputTree, fw)
    fw.close()


def grabTree(filename):
    '''
    使用pickle模块导入存储的决策树数据
    :param filename:
    :return:
    '''
    import pickle
    fr = open(filename,'rb')
    return pickle.load(fr)

if __name__ == '__main__':
    print("Test begin:")
    myDat, labels = createDataSet() 
    print(myDat)
    print(calcShannonEnt(myDat))
    print('best choice index:',chooseBestFeatureToSplit(myDat))
    myTree = createTree(myDat, labels)
    print(myTree)
    PlotDecisionTree.createPlot(myTree)
    print(classify(myTree,createDataSet()[1],myDat[2]))

第二部分利用matplot库类比决策树生成过程注意每个节点线的坐标位置递归的画出决策树的图形。有些细节我不是很懂,但我觉得画图这部分直接拿来用就好啦!

#!/usr/bin/env python
# encoding: utf-8

'''
@author: MrYx
@contact: [email protected]
@author github: https://github.com/MrYxJ
@file: UseDecisionTree.py
@time: 18-10-20 下午2:21
'''

import matplotlib.pyplot as plt

decisionNode = dict(boxstyle = "sawtooth" ,fc = "1.8")
leafNode = dict(boxstyle = "round4" , fc = "1.8")
arrow_args = dict(arrowstyle = "<-")

def plotMidText(cntrPt, parentPt, txtString):
    """
    在父子节点中填充文本信息
    :param cntrPt:
    :param parentPt:
    :param txtString:
    :return:
    """
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """
    :param myTree:
    :param parentPt:
    :param nodeTxt:
    :return:
    """
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xoff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yoff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yoff = plotTree.yoff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xoff = plotTree.xoff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xoff, plotTree.yoff), cntrPt, leafNode)
            plotMidText((plotTree.xoff, plotTree.yoff), cntrPt, str(key))
    plotTree.yoff = plotTree.yoff + 1.0/ plotTree.totalD

def plotNode(nodeTxt , centerPt , parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt,
                            xy = parentPt,
                            xycoords = 'axes fraction',
                            xytext = centerPt,
                            textcoords = 'axes fraction',
                            va="center",
                            ha = "center" ,
                            bbox = nodeType ,
                            arrowprops = arrow_args)

def getNumLeafs(myTree):
    """
    获取叶子节点个数
    :param myTree:
    :return:
    myTree格式 :嵌套的字典
    eg : {'no surrfaceing:':{0 : 'no',1 : {'flippers':{0: 'no',1:'yes'} } }
    """

    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys() :
        if type(secondDict[key]).__name__ == 'dict' : # 测试节点类型是否为字典
            numLeafs += getNumLeafs(secondDict[key])
        else :
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """
    获取叶子层数
    :param myTree:
    :return:
    """
    maxDepth = 1
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            maxDepth = max(maxDepth, getTreeDepth(secondDict[key]) + 1)

    return maxDepth

def retieveTree(i):
    """
    存储两课决策树,方便测试使用。
    :param i: 
    :return: 
    """
    listOfTrees = [{ 'no surfacing':{0: 'no', 1:{'flippers':{0:'no', 1:'yes',2:'no','flag':'yes'}}}},
                     {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}
    }]
    return listOfTrees[i]


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xoff = -0.5 / plotTree.totalW
    plotTree.yoff = 1.0
    plotTree(inTree, (0.5, 1.0), '')  # 树的引用作为父节点,但不画出来,所以用''
    plt.show()


if __name__ == '__main__':
    #createPlot()
    myTree = retieveTree(0)
    print(myTree)
    createPlot(myTree)

利用lenses数据测试以上两部分代码就不贴出,全部放在github,欢迎FollowStart!

后续更新

  • 剪枝部分的代码待完善上去
  • 目前决策树还只能处理离散数据类型,连续的数据部分待完善
  • 感觉dict构建决策树节点信息要过于简略,很多操作不能实现,待完善实现建数据结构存储决策树节点

本文参考

周志华《机器学习》
李航《机器学习实战》

猜你喜欢

转载自blog.csdn.net/yexiaohhjk/article/details/83218514