《机器学习实战》分类篇02.决策树

决策树

简单了解决策树,如下图,正方形代表判断模块(decision block),椭圆形代表终止模块(terminating block),从判断模块引出的左右箭头称作分支(branch),它可以到达另一个判断模块或终止模块。

上节学习的k-近邻算法可以完成很多分类任务,但是最大的缺点是无法给出数据的内在含义,决策树的主要优势就在于数据形式非常容易理解。
决策树很多任务都是为了数据中所蕴含的知识信息,因此决策树可以使用不熟悉的数据集合,并从中提取一系列规则,机器学习算法最终使用这些机器从数据集中创造的规则。

1.决策树的构造

优点:计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。
缺点:可能会产生过度匹配问题。
适用数据类型:数值型和标称型
创建分支的伪代码函数createBranch()如下:

检测数据集中的每个子项是否属于同一分类
	If so return 类标签;
	Else
		寻找划分数据集的最好特征
		划分数据集
		创建分支节点
			for 每个划分的子集
				调用函数createBranch并增加返回结果到分支节点中
		return 分支节点

上面的伪代码createBranch是一个递归函数,在倒数第二行直接调用了自己。
决策树的一般流程

  1. 收集数据:可以使用任何方法;
  2. 准备数据:数构造算法只适用于标称型数据,因此数值型数据必须离散化;
  3. 分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期
  4. 训练算法:构造树的数据结构;
  5. 测试算法:使用经验树计算错误率;
  6. 使用算法:此步骤可以适用于任何监督学习算法,而适用决策树可以更好地理解数据的内在含义。

下表的数据包含5个海洋动物,特征包括:不浮出水面是否可以生存,以及是否有脚蹼。

可以将这些动物分为两类:鱼类和非鱼类

编号 不浮出水面是否可以生存 是否有脚蹼 属于鱼类
1
2
3
4
5

1.1信息增益

划分数据集的大原则是:将无序的数据变得更加有序。

多种方法划分数据集,各有优缺点,可以通过计算信息增益的方式评判,而集合信息的度量方式称为香农熵或熵。

熵定义为信息的期望值,如果待分类的事务可能划分在多个分类之中,则符号xi的信息定义为:(其中p(xi)是选择该分类的概率)
l ( x i ) = l o g 2 p ( x i ) l(x_i)=-log_2p(x_i)
为了计算熵,需要计算所有类别所有可能值包含的信息期望值,通过下方公式:(其中n是分类的数目)
H = i = 1 n p ( x i ) l o g 2 p ( x i ) H=-\sum_{i=1}^np(x_i)log_2p(x_i)
下面给出Python计算信息熵的代码(在Pycharm中新建DT.py):

from math import log

def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    # 为所有可能分类创建字典
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
            labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    # 以2为底求对数
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob,2)
    return shannonEnt

接下来可以写入数据函数createDataSet():

def createDataSet():
    # 上表中的数据集
    dataSet = [[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet, labels

然后再Pycharm中写main函数:

if __name__ == '__main__':
    myDat, labels = createDataSet()
    print(myDat)
    print(calcShannonEnt(myDat))

熵越高,则混合的数据也越多,增加第三个名为maybe的分类,测试熵的变化:

    # 在main函数中增加代码
    myDat[0][-1] = 'maybe'
    print(myDat)
    print(calcShannonEnt(myDat))

1.2划分数据集

目前,已经度量数据集的无序程度(测量信息熵),接下来划分数据集:将对每个特征划分数据集的结果计算一次信息熵,然后判断按照哪个特征划分数据集是最好的划分方式。

在DT.py中添加splitDataSet()函数:

# 三个形参:待划分的数据集、划分数据集的特征和特征返回值
def splitDataSet(dataSet, axis, value):
    retDataSet = [] # 创建列表
    # 遍历数据集中每个元素,符合要求的值添加到列表中
    for featVec in dataSet:
        # 用if语句抽取出符合特征的数据
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

然后在main函数中添加:

    print(splitDataSet(myDat, 0, 1))
    print(splitDataSet(myDat, 0, 0))

接下来将遍历整个数据集,循环计算香农熵和splitDataSet()函数,找到最好的特征划分方式。在DT.py中添加chooseBestFeatureToSplit()函数:

# 满足条件:数据必须是一种由列表元素组成的列表,而且所有的列表元素都具有相同的数据长度
# 条件2:数据的最后一列或每个实例的最后一个元素是当前实例的类别标签
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        # 创建唯一的分类标签列表
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            # 计算每种划分方式的信息熵
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            # 计算最好的信息增益
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

然后在main函数中添加:

    print(chooseBestFeatureToSplit(myDat))

从运行结果看:第0个特征是最好的用于划分数据集的特征,即可以按”不浮出水面是否可以生存“。

1.3递归构建决策树

目前已经学习了从数据集构造决策树算法所需要的子功能模块,其工作原理:得到原始数据集,然后基于最好的属性值划分数据集,由于特征值可能多于两个,因此可能存在两个分支的数据集划分。第一次划分之后,数据将被向下传递到树分支的下一个节点,在这个节点上,可以再次划分数据,很符合递归原则。

递归结束的条件是:程序遍历完所哟划分数据集的属性,或者每个分支下的所有实例都具有相同的分类。如果所有实例具有相同的分类,则得到一个叶子节点或者终止块。任何到达叶子节点的数据必然属于叶子节点的分类。

如果数据集已经处理了所有属性,但是类标签依然不是唯一的,此时需要决定如何定义该叶子节点,在这种情况下,通常会采用多数表决的方法决定该叶子节点的分类。

代码实现:

先在DT.py顶部增加一行代码:import operator, 然后添加majorityCnt()函数:

def majorityCnt(classList):
    classCount = {}   # 创建列表
    # 字典对象存储了每个标签出现的频率
    for vote in classList:
        if vote not in classCount.keys():classCount[vote] = 0
        classCount[vote] += 1
    # 利用operator操作键值排序字典,并返回出现次数最多的分类名称
    sortedClassCount = sorted(classCount.items(),
                              key = operator.itemgetter(1),reverse = True)
    return sortedClassCount[0][0]

接下来继续创建树的函数:

# 输入两个参数:数据集和标签列表。标签列表包含了数据集中所有特征的标签,算法本身并不需要这个变量,但为了给出数据明确的含义,也输入这个参数。
def createTree(dataSet, labels):
    # 创建了名位classList的列表变量
    classList = [example[-1] for example in dataSet]
    # 类别完全相同则停止继续划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 遍历完所有特征时返回次数最多的
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    # 创建字典存储树的信息,这对于其后绘制树形图非常重要
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    # 得到列表包含的所有属性值
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    # 遍历当前选择特征包含的所有属性值,递归调用,得到的返回值被插入到字典变量myTree中,函数终止时,字典中将会嵌套很多代表叶子节点信息的字典数据。
    for value in uniqueVals:
        # 复制类标签,并存储在新列表subLabels
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet\
                            (dataSet, bestFeat,value),subLabels)
    return myTree

测试代码:

在main函数中添加:

    print(createTree(myDat,labels))

变量myTree包含了很多代表树结构信息的嵌套字典,第一个关键字 no surfacing 是第一个划分数据集的特征名称,该关键字的值也是另一个数据字典。第二个关键字是 no surfacing 特征划分的数据集,这些关键字的值是 no surfacing 节点的子节点。这些值可能是类标签,有可能是另一个数据字典。如果值是类标签,则该子节点是叶子节点;如果值是另一个数据字典,则子节点是一个判断节点,这个格式结构不断重复构成整棵树。

2.在Python中使用Matplotlib注解绘制树形图

上节已经正确地从数据集中构造树,接下来绘制图形,方便正确理解数据信息。

主要就是绘制如下图的决策树:

2.1Matplotlib注解

Matplotlib提供了一个注解工具annotations,非常有用,它可以在数据图形上添加文本注解。由于数据上面直接存在文本描述非常丑陋,因此工具内嵌支持带箭头的划线工具。恰好可以使用该注解功能绘制树形图

使用文本注解绘制树节点的实现:

在pycharm中新建DTPlotter.py, 写入下列代码:

import matplotlib.pyplot as plt
# 全局设置中文字体,为了输出中文
from pylab import mpl
mpl.rcParams['font.sans-serif'] = ['SimHei'] # 微软雅黑

# 定义文本框和箭头格式
decisionNone = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

# 绘制带箭头的注解
def plotNode(nodeTxt, centerrPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy = parentPt,
                            xycoords = 'axes fraction',
                            xytext = centerrPt,
                            textcoords = 'axes fraction',
                            va = "center",
                            ha = "center",
                            bbox = nodeType,
                            arrowprops = arrow_args)

def createPlot():
    # 创建一个新图形
    fig = plt.figure(1, facecolor = 'white')
    fig.clf()  # 清空绘图区
    createPlot.ax1 = plt.subplot(111, frameon = False)
    # 绘制两个代表不同类型的树节点
    plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNone)
    plotNode('叶节点', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()  # 输出显示图形
    
if __name__ == '__main__':
    print(createPlot()) # 使用文本注解绘制树节点

在pycharm中ctrl+shift+F10,运行DTPlotter.py,输出图形:

2.2构建注解树

绘制一棵完整的树需要一些技巧,虽然有了x, y坐标,但如何放置所有的树节点却是个问题。因此必须知道有多少个叶节点,以便可以正确确定x轴的长度;还需要知道树有多少层,以便可以正确确定y轴的高度。故继续定义两个函数getNumLeafs()和getTreeDepth(),来获取叶节点的数目和树的层数。

继续在DTPlotter.py, 写入下列代码:

def getNumLeafs(myTree):
    numLeafs = 0
    # 第一个关键字是第一次划分数据集的类别标签
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    # 测试节点的数据类型是否为字典
    # 如果子节点是字典类型,则该节点也是一个判断节点,需要递归调用getNunLeafs()函数
    # 遍历整棵树,累计叶子节点的个数,并返回该数值
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

# 和getNumLeafs()函数有点相似
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

# 预先存储树信息,避免每次测试代码都从数据中创建树的麻烦
def retrieveTree(i):
    listOfTrees = [ {'no surfacing': {0: 'no', 1: {'flippers':
                    {0: 'no', 1: 'yes'}}}},
                    {'no surfacing': {0: 'no', 1: {'flippers':
                    {0: {'head':{0:'no',1:'yes'}}, 1: 'no'}}}} ]
    return listOfTrees[i]

而在main函数中添加:( 同时注释掉print(createPlot()) )

if __name__ == '__main__':
    # print(createPlot()) # 使用文本注解绘制树节点
    print(retrieveTree(1))
    myTree = retrieveTree(0)
    print(getNumLeafs(myTree))
    print(getTreeDepth(myTree))

输出结果:

函数retrieveTree()主要用于测试,返回预定义的树结构,调用getNumLeafs()函数返回3,等于树0的叶子节点数;调用getTreeDepths()函数也能够正确返回数的层数。

但输出没有绘制一棵完整的树,尽管已经定义了createPlot()函数,但需要更新这部分代码,把树信息传进去这个函数。

更新createPlot()函数,并新增plotMidText()函数和plotTree()函数:

# 作用是计算tree的中间位置
# cntrpt起始位置,parentpt终止位置,txtstrin文本标签信息
def plotMidText(cntrPt, parentPt, txtString):
    # 找到x和y的中间位置
    xMid = (parentPt[0] - cntrPt[0]/2.0 + cntrPt[0])
    yMid = (parentPt[1] - cntrPt[1]/2.0 + cntrPt[1])
    createPlot.ax1.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    # 计算子节点的坐标
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # 绘制线上的文字
    plotNode(firstStr, cntrPt, parentPt, decisionNone)  # 绘制节点
    secondDict = myTree[firstStr]
    # 每绘制一次图,将y的坐标减少1.0/plottree.totald,间接保证y坐标上深度的
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
    
def createPlot(inTree):
    # 创建一个新图形
    fig = plt.figure(1, facecolor = 'white')
    fig.clf()   # 清空绘图区
    axprops = dict(xticks = [], yticks = [])
    # subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图
    # frameon表示是否绘制坐标轴矩形
    createPlot.ax1 = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
# xOff和yOff用来记录当前要画的叶子节点的位置
# cntrPt记录当前要画的树的树根的结点位置
# x轴和y轴的范围都是[0.0~1.0],输出的图形是按比例绘制树形图,不担心变形,不建议用像素为单位绘制图形

接下来更新main函数:

if __name__ == '__main__':
    # print(createPlot()) # 使用文本注解绘制树节点
    # print(retrieveTree(1))
    # myTree = retrieveTree(0)
    # print(getNumLeafs(myTree))
    # print(getTreeDepth(myTree))
    myTree = retrieveTree(1)
    print(createPlot(myTree))

输出结果:

3.测试和存储分类器

接下来在真实数据上使用决策树分类算法,验证决策树是否可以正确预测患者应该使用的隐形眼镜类型。

3.1测试算法:使用决策树执行分类

在执行数据分类时,需要决策树以及用于构造树的标签向量。然后,程序比较测试数据与决策树上的数值,递归执行该过程直到进入叶子节点;最后将测试数据定义为叶子节点所属的类型。故继续在DT.py中添加classify()函数:

# 存储带有特征的数据会面临一个问题:程序无法确定特征在数据集中的位置
# 使用index方法查找当前列表中第一个匹配firstStr变量的元素
def classify(inputTree, featLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    # 将标签字符串转换为索引
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__=='dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel    # 返回遍历后当前节点的分类标签

接下来在main函数中调用函数:

但是发现了尴尬的问题,我之前在DT.py中写了main函数,在DTPlotter.py中也写了main函数,但测试算法时,我需要调用DT.py中的函数,也需要调用DTPlotter.py的函数,与其把DT.py导入到DTPlotter.py中,我在想还不如干脆想c++一样单独建一个main函数,导入DT.py和DTPlotter.py(类比c++导入头文件),试了果然可以。(果然编程能力也是在实战中进步的。我还真是个憨憨,尴尬)

接下来单独新建main.py:

import DT
import DTPlotter

if __name__ == '__main__':
    myDat, labels = DT.createDataSet()
    print(labels)
    myTree = DTPlotter.retrieveTree(0)
    print(myTree)
    print(DT.classify(myTree, labels, [1,0]))
    print(DT.classify(myTree, labels, [1,1]))

运行,输出结果:

DT11.jpg

输出结果与上节输出结果比较:第一节点名为:no surfacing,它有两个子节点:一个是名字为0的叶子节点,类标签为no;另一个是名为flippers的判断节点,此处进入递归调用,flippers节点有两个子节点。

3.2使用算法:决策树的存储

构造决策树是很耗时的任务,即使处理很小的数据集,如前面的样本数据,也要花费几秒的时间,如果数据集很大,将会耗费很多计算时间。所以为了节省计算时间,最好能够在每次执行分类时调用已经构造好的决策树。为了解决这个问题,需要使用Python模块pickle序列化对象,序列化对象可以在磁盘上保存对象,并在需要的时候读取出来,任何时候都可以执行序列化操作,字典对象也不列外。

在DT.py中添加:

# 使用pickle模块存储决策树
def storeTree(inputTree, filename):
    import pickle
    fw = open(filename, 'wb') # python2是 w,python3是 wb
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):
    import pickle
    fr = open(filename,'rb')  # 对应的打开时要用 rb
    return pickle.load(fr)

然后在main.py中添加:

    DT.storeTree(myTree, 'classifierStorage.txt')
    DT.grabTree('classifierStorage.txt')

输出结果:

DT12.jpg

而且生成了classifierStorage.txt文件,这样,就不用每次对数据分类时重新学习一遍,这也是决策树的优点之一。

4.示例:使用决策树预测隐形眼镜类型

案例:眼科医生是如何判断患者需要佩戴的镜片类型;一旦理解了决策树的工作原理,也可以帮助人们判断需要佩戴的镜片类型。

示例:使用决策树预测隐形眼镜类型

  1. 收集数据:提供的文本文件
  2. 准备数据:解析tab键分隔的数据行
  3. 分析数据:快速检查数据,确保正确地解析数据内容,使用createPlot()函数绘制最终的树形图。
  4. 训练算法:使用createTree()函数
  5. 测试算法:编写测试函数验证决策树可以正确分类给定的数据实例
  6. 使用算法:存储树的数据结构,以便下次使用时无需重新构建树

隐形眼镜数据集是非常著名的数据集,它包含很多患者眼部状况的观察条件以及医生推荐的隐形眼镜类型。隐形眼镜类型包括硬材质、软材质以及不适合佩戴隐形眼镜。(数据来源于:UCI机器学习存储库 https://archive.ics.uci.edu/ml/index.php )

继续在main.py中添加代码:

    fr=open('lenses.txt')
    lenses=[inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels=['age','prescript','astigmatic','tearRate']
    lensesTree = DT.createTree(lenses,lensesLabels)
    print(lensesTree)
    DTPlotter.createPlot(lensesTree)

输出结果:(树形图构造好像不对,尴尬)

有空仔细查看哪个函数写错了。

这次使用的算法称为ID3,它是一个好的算法但并不完美,之后会学另一个决策树构造算法CART。ID3算法无法直接处理数值型数据,尽管可以通过量化的方法将数值型数据转化为标称型数值,但如果存在太多的特征划分,ID3算法仍然会面临其它问题。

5.总结

决策树分类器就像带有终止块的流程图,终止块表示分类结果。开始处理数据集时,开始处理数据集时,需要先测量集合中数据的不一致性,也就是熵,然后寻找最优方案划分数据集,直到数据集中的所有数据属于同一分类。ID3算法可以用于划分标称型数据集。构建决策树时,通常采用递归的方法将数据集转化为决策树。一般不构造新的数据结构,而是使用Python语言内嵌的数据结构字典存储树节点信息。

使用Matplotlib的注解功能,可以将存储的树结构转化为容易理解的图形。Python语言的pickle模块可用于存储决策树的结构。隐形眼镜的例子表明决策树可能会产生过多的数据集划分,从而产生过度匹配数据集的问题。可以通过裁剪决策树,合并相邻的无法产生大量信息增益的叶节点,消除过度匹配问题。

还有其它决策树的构造算法,最流行的是C4.5和CART。

参考文献:

  1. 《机器学习实战》-k近邻算法

革命尚未成功,同志仍需努力!

发布了61 篇原创文章 · 获赞 7 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/MRZHUGH/article/details/102946077
今日推荐