菜鸟学《机器学习实战》(二)——决策树

决策树:和k近邻算法一样,也是用来做分类用的。简单地说就是将一个未知分类的事务依次通过多个判断条件,根据符不符合条件来进行分类,递归下降,直至判断出确定类型。PS.偷一张书上邮件分类的图嘻嘻....


决策树特性(本菜鸡认为的重点):

优点:复杂度不高,易于理解,中间值缺失不敏感(就是缺几个值也问题不大)

缺点:可能过度匹配

使用范围:数值型、标称型。

1、知识补充——信息增益

    根据香农大佬钦定的基本法,符号xi的信息定义为


其中p(xi)是选择该分类的概率

信息期望值为


2、梳理实战代码

好了,有了信息熵的知识铺垫,我们可以开始了

首先创建trees.py文件

import operator
from math import log  # 通过math包导入log函数


# 计算给定数据的熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)  # 数据集长度
    labelCounts = {}  # 定义标签字典
    for featVec in dataSet:
        currentLabel = featVec[-1]  # 最后一列数值,即标签数值
        if currentLabel not in labelCounts.keys():  # 若不存在
            labelCounts[currentLabel] = 0  # 添加入标签字典,并初始化为0
        labelCounts[currentLabel] += 1  # 标签频数自增
    shannonEnt = 0.0  # 定义香农熵,注意是浮点型,PS entropy 熵
    for key in labelCounts:  # 计算每个标签的香农熵
        prob = float(labelCounts[key]) / numEntries  # 计算频率
        shannonEnt -= prob * log(prob, 2)  # 代入之前的信息期望值公式
    return shannonEnt  # 返回香农熵


def createDataSet():
    # 数据集
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    # 标签集
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

然后在命令行输入:

>>> from imp import reload
>>> import trees
>>> reload(trees)
<module 'trees' from 'C:\\Users\\dell\\PycharmProjects\\untitled\\trees.py'>
>>> myDat,labels=trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.calcShannonEnt(myDat)

0.9709505944546686

结果正确,很开森!

决策树分类算法除了需要测量信息熵,还需要划分数据集,度量划分数据集的熵,以判断当前是否正确地划分了数据集。

好了,用正常人听得懂的话解释就是:在分类的过程中,将哪个属性作为新的判断结点,取决于这个属性是否为当前“最好的”分类标准。“最好”即用信息增益最大,即按照此种分类计算(之前的代码干的活)出的信息熵是最大的。

下面接着在trees.py中继续书写代码。

#参数解释
#dataSet 待划分的数据集
#axis 划分数据集的特征
#value 特征的返回值
def splitDataSet(dataSet, axis, value):
    # python的引用特性决定了不创建新的对象则在函数内部对列表对象的修改,将会影响该列表对象的整个生存周期
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:#一旦在下标为axis的属性发现符合要求的值,即等于value,则将其添加到新创建的列表retDataSet中
            #python通过下标子序列中[a,b]实际上只取出下标为a到下标为b-1的所有元素
            #换句话说,featVec[:axis]就漏掉了featVec中下表为axis的元素
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis + 1:])#extend是将所有新元素直接加入序列中
            retDataSet.append(reducedFeatVec)#append是将所有新元素视作一个集合,然后再整体作为一个元素加入到序列中
    #返回决策树
    return retDataSet

接着,我们在命令行中输入:

>>> from imp import reload
>>> import trees
>>> reload(trees)
<module 'trees' from 'C:\\Users\\dell\\PycharmProjects\\untitled\\trees.py'>
>>> myDat,labels=trees.createDataSet()
>>> myDat
[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
>>> trees.splitDataSet(myDat,0,1)
[[1, 'yes'], [1, 'yes'], [0, 'no']]
>>> trees.splitDataSet(myDat,0,0)
[[1, 'no'], [1, 'no']]

下面,我们需要选择最好的数据集划分方式,即确定之前所说的分类结点

# 选择最好的数据集划分方式
# 选取特征,划分数据集,计算得出最好的划分数据集的特征
# dataSet要求:1、由列表组成的列表,而且所有的列表元素都要有相同的数据长度;2、数据的最后一列是实例的类别标签
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) - 1  # 取除了最后一个元素的前面其他所有属性
    baseEntropy = calcShannonEnt(dataSet)  # 香农熵
    bestInfoGain = 0.0  # 信息增益
    bestFeature = -1  # 最好分类标准的列表下标
    for i in range(numFeatures):
        # 通过set创建唯一的分类标签列表
        # 这种优秀列表创建方式可以学习一波
        # 特别注意,dataSet是由列表组成的列表,所以是example[i]而不是example,取出的实际上是一整个属性列
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)  # 集合类属性唯一
        newEntropy = 0.0  # 浮点数注意
        # 计算每种划分方式的信息熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        # 计算最好的信息增益
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    #返回最优划分属性
    return bestFeature

接着在命令行中测试输入:

>>> reload(trees)
<module 'trees' from 'C:\\Users\\dell\\PycharmProjects\\untitled\\trees.py'>
>>> myDat,labels=trees.createDataSet()
>>> trees.chooseBestFeatureToSplit(myDat)
0
>>> myDat

[[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]

结果正确,开心!

加下来来学习构建递归构建决策树

首先定义函数majority,和之前k近邻算法中的classify0函数类似,不再赘述。

# classList 分类名称列表
# 可以参看之前k近邻算法中的classify0函数,就不再解释了
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0;
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
接着正式递归构建决策树:

递归终止条件(满足其一即可):

(1)程序遍历完所有划分数据集的属性

(2)每个分支下的所有实例都具有相同的分类

接着输入代码

def createTree(dataSet, labels):
    # 最后一个是标签
    classList = [example[-1] for example in dataSet]
    # 类别完全相同则停止继续划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # 遍历玩所有特征时返回出现次数最多的
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 找到最合适的分割
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}  # 开始构造决策树
    del (labels[bestFeat])  # 在原有标签中删除最佳分割属性
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)  # 通过集合删除重复元素
    for value in uniqueVals:
        subLabels = labels[:]  # 拷贝剩下的标签数据
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)  # 此处用了递归的思想
    # 返回生成的生成树
    return myTree

在命令行中测试:

>>> from imp import reload
>>> import trees
>>> reload(trees)
<module 'trees' from 'C:\\Users\\dell\\PycharmProjects\\untitled\\trees.py'>
>>> myDat,labels=trees.createDataSet()
>>> myTree=trees.createTree(myDat,labels)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}

结果正确

接下来,我们来学习使用Matplotlib绘制树形图

创建treePlotter.py文件用于绘图,继续书写python代码 PS. plotter 绘图机

import matplotlib.pyplot as plt#引入matplotlib包中的pyplot

#判断结点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")#boxstyle 文本框类型 sawtooth 锯齿形 fc 边框线粗细
#叶子节点
leafNode = dict(boxstyle="round4", fc="0.8")#round4 圆边
#箭头
arrow_args = dict(arrowstyle="<-")#arrowstyle 箭头风格
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # annotate是关于一个数据点的文本
    # nodeTxt为要显示的文本,centerPt为文本的中心点,箭头所在的点,parentPt为指向文本的点
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',\
             xytext=centerPt, textcoords='axes fraction',\
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
def createPlot():
   fig = plt.figure(1, facecolor='white')#背景为白色
   fig.clf()# 把画布清空
   # createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图
   createPlot.ax1 = plt.subplot(111, frameon=False)
   #绘制判断结点 名字 箭头 箭尾 结点类型

接着我们在命令行中测试我们的绘图代码

>>> import treePlotter
>>> treePlotter.createPlot()

绘图结果如下


结果正确

接着,我们来构造注解树

首先我们要知道树有多少个叶子节点,以确定x轴的长度;还需要知道树有多少层,以确定y轴的高度。

在treePlotter.py中,定义getNumLeafs()获取叶子数目,定义getTreeDepth()获取树的层数,来,在treePlotter.py中接着敲代码如下:

#获取树的层数
def getTreeDepth(myTree):
    maxDepth = 0
    #python2 版本
    # firstStr = myTree.keys()[0]
    #python3 版本
    firstSides=list(myTree.keys())
    firstStr=firstSides[0]

    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:#叶子节点在此处即条件为假,终止递归并直接返回,并将计算树深度的变量加一
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth

#为节省时间,函数retrieveTrees输出预先存储的树信息,避免了每次测试代码时都要从数据中创建树的麻烦
#retrieve 取回,恢复
def retrieveTree(i):
    listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                 {'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}
                 ]
    return listOfTrees[i]

接着在命令行中检测:

>>> reload(treePlotter)
<module 'treePlotter' from 'C:\\Users\\dell\\PycharmProjects\\untitled\\treePlotter.
py'>
>>> treePlotter.retrieveTree(1)
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no
'}}}}
>>> myTree=treePlotter.retrieveTree(0)
>>> treePlotter.getNumLeafs(myTree)
3
>>> treePlotter.getTreeDepth(myTree)

2

现在,我们把之前的方法组合到一起,绘制一棵完整的树。现在继续完善treePlotter.py,并更新其中的createPlot()函数

在treePlotter中继续输入

# 在父子节点间填充文本信息
# 计算父节点和子节点的中间位置,并添加简单的文本标签信息
def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)  # 是ax1是一不是L


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # 宽
    depth = getTreeDepth(myTree)  # 高
    # python2版本
    # firstStr = myTree.keys()[0]
    # python3版本
    firstSides = list(myTree.keys())
    firstStr = firstSides[0]
    # totalW存储树的宽度
    # totalD存储树的深度
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # 标记子节点属性值
    # 计算父节点和子节点的中间位置,并添加简单的文本标签信息
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # 减少y偏移
    # 按比例减少全局变量plotTree.yOff,并标注此处将要绘制子节点,既可以是叶子也可以是判断
    # 因为我们是自顶向下绘制图形,因此需要依次减少y坐标值,而不是递增y坐标值
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':  # 不是叶子节点继续递归
            plotTree(secondDict[key], cntrPt, str(key))  # 递归遍历整棵树
        else:  # 是叶子节点画出叶子节点
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


# 创建绘图区,计算树形图全局尺寸,并递归调用函数plotTree()
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # 背景为白色
    fig.clf()  # 把画布清空
    axprops = dict(xticks=[], yticks=[])
    # createPlot.ax1为全局变量,绘制图像的句柄,subplot为定义了一个绘图,111表示figure中的图有1行1列,即1个,最后的1代表第一个图
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW;
    plotTree.yOff = 1.0;
    plotTree(inTree, (0.5, 1.0), '')
    # 展示图画
    plt.show()

命令行检测图形:

>>> reload(treePlotter)
<module 'treePlotter' from 'C:\\Users\\dell\\PycharmProjects\\untitled\\treePlotter.
py'>
>>> myTree=treePlotter.retrieveTree(0)

>>> treePlotter.createPlot(myTree)

图案如下:


接着,测试和存储分类器,把重心转移到如何利用决策树执行数据分类

将以下代码加入到trees.py中

def classify(inputTree, featLabels, testVec):
    firstSides = list(inputTree.keys())
    firstStr = firstSides[0]
    secondDict = inputTree[firstStr]
    # 查找当前列表中第一个匹配firstStr变量的元素
    featIndex = featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                #递归遍历整棵树
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel

下面接着套路般的命令行检测:

>>> reload(trees)
<module 'trees' from 'C:\\Users\\dell\\PycharmProjects\\untitled\\trees.py'>
>>> myDat,labels=trees.createDataSet()
>>> labels
['no surfacing', 'flippers']
>>> myTree=treePlotter.retrieveTree(0)
>>> myTree
{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
>>> trees.classify(myTree,labels,[1,0])
'no'
>>> trees.classify(myTree,labels,[1,1])

'yes'

检测结果正确

感谢阅读到最后,发现错误欢迎指正。

猜你喜欢

转载自blog.csdn.net/Zjhao666/article/details/80469680