MLiA笔记_treeplotter

#-*-coding:utf-8-*-
# 3.5 使用文本注解绘制树节点
import matplotlib.pyplot as plt

# 代码定义树节点格式的常量
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle="round4", fc = "0.8")
arrow_args = dict(arrowstyle="<-")


# 然后定义plorNode()函数执行了实际的绘图功能,该函数需要一个绘图区,该区域由全局变量createPlot.ax1定义。
# python语言中所有的变量默认都是全局有效的
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext = centerPt, textcoords='axes fraction',
                            va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)

# 最后定义createPlot()函数,首先创建了一个新图形并清空绘图去,然后再绘图区上绘制两个代表不同类型的树节点,后面用这两个树节点绘制树形图
def createPlot():
    fig = plt.figure(1, facecolor="white")
    fig.clf()
    createPlot.axl = plt.subplot(111, frameon = False)
    plotNode('a decision node', (0.5,0.1),(0.1,0.5), decisionNode)
    plotNode('a leaf node',(0.8,0.1),(0.3,0.8),leafNode)
    plt.show()


# 3.6 获取叶节点的数目和树的层数

# getNumLeafs()和getTreeDepth()具有相同的结构.这里使用的数据结构说明了如何在python字典类型中存储树信息。
# 第一个关键字是第一次划分数据集的类别标签, 附带的数值表示子节点的取值。
# 从第一个关键字出发我们可以遍历整棵树的所有子节点,使用python提供的type()函数可以判断子节点是否为字典类型。
    # 如果子节点是字典类型,则该节点也是一个判断结点,需要递归调用getNumLeafs()函数。
    # getNumLeafs()函数遍历整棵树,累计叶子节点的个数,并返回该数值
    # getTreeDepth()函数计算遍历过程中遇到判断结点的个数,该函数的终止条件是叶子节点,一旦到达叶子节点,则从递归调用中返回,并将计算树深度的变量加一。
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1+getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

# 为节约时间,函数retrieveTree输出预先存储的树信息,避免了每次测试代码时都要从数据中创建树的麻烦
def retrieveTree(i):
    listOfTrees = [{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},
                   {'no surfacing':{0:'no',1:{'flippers':{0:{'haed':{0:'no',1:'yes'}},1:'no'}}}}]
    return listOfTrees[i]

# 在父子节点间填充文本信息
def plotMidText(cntrPt,parentPt,txtString):
    xMid = (parentPt[0]-cntrPt[0]/2.0+cntrPt[0])
    yMid = (parentPt[1]-cntrPt[1]/2.0+cntrPt[1])
    createPlot.axl.text(xMid, yMid, txtString)

# 函数plotTree()依次调用了前面介绍的函数和plotMidText(),绘制树形图的很多工作都是在函数plotTree()中完成的
# plotTree()也是个递归函数,树的宽度用于计算放置判断节点的位置,主要的计算原则是将它放在所有叶子节点的中间,而不是它子节点的中间
# 另一个需要说明的问题是,绘制图形的x轴和y轴有效范围是0.0到1.0
# 按照叶子节点数将x轴划分为若干部分,按照凸性比例绘制树形图的最大好处是无需关心实际输出图形的大小,一旦图形大小发生变化,函数会自动按照图形大小重新绘制。。
def plotTree(myTree, parentPt, nodeTxt):
    # 首先计算树的宽和高
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    getTreeDepth(myTree)
    firstStr = myTree.keys()[0]
    # 同时我们使用两个全局变量ployTree.xOfff和ployTree.yOfff追踪已经绘制的节点位置,以及放置下一个节点的恰当位置。
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    # 接着,绘出子节点具有的特征值,或者沿此分支向下的数据实例必须具有的特征值,
    # 使用plotMidText()函数计算父节点和子节点的中间位置,并在此处添加简单的文本标签信息。
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr,cntrPt,parentPt,decisionNode)
    secondDict = myTree[firstStr]
    # 然后,按比例减少全局变量plotTree.yOff,并标注此处将要绘制子节点
    # 这些结点既可以是叶子节点也可以是判断结点,此处需要只保存绘制图形的轨迹。因为我们是自顶向下绘制图形,因此需要依次递减y坐标值
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


# createPlotNew()是主函数。调用了函数plotTree(),
def createPlotNew(inTree):
    fig = plt.figure(1,facecolor = 'white')
    fig.clf()
    axprops = dict(xticks=[], yticks =[])
    createPlot.axl = plt.subplot(111, frameon = False, **axprops)
    # 全局变量plotTree.totalW存储树的宽度,全局变量plotTree.totalD存储树的深度,我们使用这两个变量计算树节点的摆放位置,这样可以将树绘制在水平方向和垂直方向的中心位置
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW ; plotTree.yOff = 1.0; plotTree(inTree,(0.5,1.0),'')
    plt.show()

猜你喜欢

转载自blog.csdn.net/weixin_42836351/article/details/81300926