手撸决策树代码——原理详解(2)(python3)

第四步:递归创建字典树

构建决策字典树用到的最基本的思想是递归
在构建过程中:我们需要用到第一步和第三步的函数,通过第三步得到的最好的划分方式不断的作为当前树的根标签,并将第一步划分的子数据集作为下层使用,不断递归
这个递归有两个结束条件,写在了代码注释下

def createTree(dataSet, labels):
    classList = [example[-1] for example in dataSet]  # 提取最后一列的判断结果,返回列表
    # 递归结束条件一:当发现到叶子结点,也就是所有的判断特征都用完了,判断结果都一致
    if classList.count(classList[0]) == len(classList):  # 检查是不是所有判断结果都是一样的,一样的那么决策树就只有一个根节点和一个子节点
        return classList[0]
    # 递归结束条件二:当发现到叶子结点,判断特征用完后,判断结果仍然有分歧,我们将判断结果较多的作为结果返回
    # 递归条件二要结合第一个条件判断,既然能到第二个条件,说明判断结果有不一样的
    # 而且再检查它还有几个属性,发现只剩下一个属性了,不能将判断结果划分了
    if len(dataSet[0]) == 1:  # 一般前面用不到,递归到最底层的叶子结点时执行
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)  # 返回最佳的属性划分方式
    bestFeatLabel = labels[bestFeat]  # 并得到该属性的字符串名
    # 由于{}设置为空,即键的值为空,所以下层字典树的创建在该{}里创建
    myTree = {bestFeatLabel: {}}  # 开始创建字典树,bestFeatLabel为根结点标签,下一层的标签放在后面的括号内
    del (labels[bestFeat])  # 删除改属性的字符串名
    featValues = [example[bestFeat] for example in dataSet]  # 把这个属性的所有特征拿出来
    uniqueVals = set(featValues)  # 删除重复的特征
    for value in uniqueVals:
        subLabels = labels[:]  # 创建新的子标签
        # split函数获取去除了属性bestFeat下value的子数据集并获得子标签,递归创建子树
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

第四步的子步: 优化

并不是每次数据集我们都可以完整一个不差划完的,当我们发现我们可以用的判断条件都用完了后,判断结果仍出现分歧,此时就要优化。
例如:假设我们划分好人坏人,特征用的是知识丰富,有钱,活泼开朗,家庭状况,是否有配偶,其结果发现,最后的判断结果仍然区分不出好人坏人时,我们将数据传入该函数,该函数将判断结果数量多的那一方传给我们作为最后的结果。

# 针对所有特征都用完,但是最后一个特征中类别还是存在很大差异,
# 比如西瓜颜色为青绿的情况下同时存在好瓜和坏瓜,无法进行划分,此时选取该类别中最多的类
def majorityCnt(classlist):  # 作为划分的返回值,majorityCnt的作用就是找到类别最多的一个作为返回值
    classCount = {}
    for vote in classlist:  # 寻找classlist里的值将存入字典,并记录该值在classlist里出现的次数
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
        # 将classcount里的值进行排序,大的在前
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)

    return classCount[0][0]  # 返回最大值

第五步 做画出树图像的准备工作

5.1定义结点并定义结点和箭头绘制函数

运用matplotlib工具,我们对字典树进行绘制和注解,首先我们要将根结点和叶子结点进行区分。

import matplotlib

matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

#  定义两种结点类型
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
#  boxstyle为文本框的类型,sawtooth是锯齿形,ec是边框线颜色,edgecolor,ew为edgewidth
leafNode = dict(boxstyle="round4", fc="0.8")
#  定义叶节点,round4是方框椭圆型
arrow_args = dict(arrowstyle="<-")
#  定义箭头方向 与常规的方向相反
#  Advanced Raster Graphics System高级光栅图形系统

# 声明绘制一个节点的函数
'''
annotate是关于一个数据点的文本 相当于注释的作用 
nodeTxt:即为文本框或锯齿形里面的文本内容
nodeType:是判断节点还是叶子节点
bbox给标题增加外框
nodeTxt为要显示的文本
centerPt为文本的中心点,箭头所在的点
parentPt为指向文本的点  pt为point

'''


def plotNode(nodeTxt, centerPt, parentPt, nodeType):  # 输入4个参数:结点文字,终点,起点,结点的类型
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType,
                            arrowprops=arrow_args)
5.2在结点之间填充属性的特征的文本
#  在父子结点的箭头的中间处填充文本
def plotMidText(cntrPt, parentPt, txtString):  # 分别输入终点,起点,和文本内容
    # 找到输入文本的位置,即终点和起点连线的中点处
    xMid = (parentPt[0] + cntrPt[0]) / 2.0
    yMid = (parentPt[1] + cntrPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, txtString)  # 在(xMid,yMid)位置填充txtString文本
5.3获取该字典树的深度和叶子结点个数
#  获取字典树的叶子结点个数
def getNumLeafs(myTree):
    numLeafs = 0  # 定义记录叶节点的数目的变量
    firstside = list(myTree.keys())  # 我们获取当前输入的树的所有键,并将其转换成列表
    firstStr = firstside[0]  # 并把当前列表第一个结点(当前树的根节点)的键获取
    secondDict = myTree[firstStr]  # 去输入的字典树里找这个键对应的值,存入另一个空字典
    for key in secondDict.keys():  # 查找存入这个字典的值它是不是字典类型;是说明下面还有分支,不是说明是叶子结点
        if type(secondDict[key]) == dict:
            numLeafs += getNumLeafs(secondDict[key])  # 是字典类型就递归找子节点是不是叶子结点
        else:
            numLeafs += 1  # 不是字典类型说明是叶子结点,数量加一,并返回上一层
    return numLeafs

#  获取字典树的深度
def getTreeDepth(myTree):
    maxDepth = 0
    firstside = list(myTree.keys())   # 我们获取当前输入的树的所有键,并将其转换成列表
    firstStr = firstside[0]  # 并把当前列表第一个结点(当前树的根节点)的键获取
    secondDict = myTree[firstStr]  # 去输入的字典树里找这个键对应的值,存入另一个空字典
    for key in secondDict.keys():  # 查找存入这个字典的值它是不是字典类型;是说明下面还有分支,不是说明下面没有深度可寻
        if type(secondDict[key]) == dict:
            thisDepth = 1 + getTreeDepth(secondDict[key])  # 是字典类型,继续寻找下面分支的深度,并将当前深度记录加一
        else:
            thisDepth = 1  # 如果刚开始就只有根节点就返回深度一,如果后面递归到这里,发现不是字典类型,返回的1值没有用,意义是使下面比较时保持返回的maxdepth不变,
        if thisDepth > maxDepth: maxDepth = thisDepth  # 每层都比较更新一下树的最大深度,并返回上层
    return maxDepth

自此,我们做了绘制字典树的所有准备工作,接下来我们可以开始绘制字典树了。

发布了19 篇原创文章 · 获赞 4 · 访问量 497

猜你喜欢

转载自blog.csdn.net/qq_35050438/article/details/103498121