[5] Drawing tree diagrams in Python with Matplotlib annotations

Matplotlib provides an annotation tool, `annotate()`, that can add text annotations to data plots.

1. Draw tree nodes with text annotations

import matplotlib.pyplot as plt

# Box style for decision (internal) nodes: sawtooth outline, gray fill "0.8".
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# Box style for leaf nodes: "round4" outline, same gray fill.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style used for the line connecting each node to its parent.
arrow_args = {"arrowstyle": "<-"}

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw one boxed node label and connect it to its parent with an arrow.

    Args:
        nodeTxt: text shown inside the node box.
        centerPt: (x, y) of the node, in axes-fraction coordinates.
        parentPt: (x, y) of the parent node, in axes-fraction coordinates.
        nodeType: bbox style dict (e.g. ``decisionNode`` or ``leafNode``).

    Draws onto the axes stored in ``createPlot.ax1``; the arrow style comes
    from the module-level ``arrow_args``.
    """
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )
    

def createPlot():
    """Demo: draw a single decision node and a single leaf node, then show them.

    The axes are stored on ``createPlot.ax1`` so that ``plotNode`` can draw
    onto them.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # Frame hidden, but ticks are kept on for demonstration purposes.
    createPlot.ax1 = plt.subplot(111, frameon=False)
    demo_nodes = (
        ('a decision node', (0.5, 0.1), (0.1, 0.5), decisionNode),
        (' a leaf node ', (0.8, 0.1), (0.3, 0.8), leafNode),
    )
    for text, center, parent, style in demo_nodes:
        plotNode(text, center, parent, style)
    plt.show()
The nodes above are drawn with text annotations.
  • About `dict` (used above to hold the box and arrow styles)

Detailed explanation: http://www.cnblogs.com/yangyongzhi/archive/2012/09/17/2688326.html

Dictionary initialization:

>>> d=dict(name='vi',age=20)
>>> d
{'name': 'vi', 'age': 20}

Access the dictionary:

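(Avoid naming a variable `dict`: that shadows the built-in `dict()` that the plotting code in this article calls.)
>>> info = {'Name': 'Zara', 'Age': 7, 'Class': 'First'}
>>> print("info['Name']:", info['Name'])
info['Name']: Zara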

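The annotate() function draws a labelled arrow. Its two positions are the point being annotated (`xy`, here the parent node) and the position of the text (`xytext`, here the node itself); the arrow connects them. The remaining arguments carry style information such as the box color.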

 

2. Construct the annotation tree

First get the number of leaf nodes and the number of levels of the tree:

def getNumLeafs(myTree):
    """Count the leaf nodes of a nested-dict decision tree.

    The tree has the form ``{feature: {branch_value: child, ...}}``: only the
    first key of ``myTree`` is examined, and each child is either another tree
    (a dict) or a leaf label (anything else).

    Args:
        myTree: nested dict tree.

    Returns:
        int: number of leaf (non-dict) children reachable from the root;
        0 if the root has no children.

    Raises:
        IndexError: if ``myTree`` is empty.
    """
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    numLeafs = 0
    for child in secondDict.values():
        # isinstance (rather than comparing type names) also recognises dict
        # subclasses such as OrderedDict as subtrees instead of leaves.
        if isinstance(child, dict):
            numLeafs += getNumLeafs(child)
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    """Return the number of decision levels in a nested-dict decision tree.

    The tree has the form ``{feature: {branch_value: child, ...}}``: only the
    first key of ``myTree`` is examined, and each child is either another tree
    (a dict) or a leaf label (anything else).

    Args:
        myTree: nested dict tree.

    Returns:
        int: 1 if every child of the root is a leaf, 1 + the deepest subtree
        otherwise, and 0 if the root has no children.

    Raises:
        IndexError: if ``myTree`` is empty.
    """
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    maxDepth = 0
    for child in secondDict.values():
        # isinstance also recognises dict subclasses (e.g. OrderedDict) as
        # subtrees, matching how leaves are counted.
        if isinstance(child, dict):
            thisDepth = 1 + getTreeDepth(child)
        else:
            thisDepth = 1
        maxDepth = max(maxDepth, thisDepth)
    return maxDepth
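These two functions return the number of leaf nodes and the number of levels (depth) of the tree.

Both walk the children of the root node and check whether each child is a dictionary. If it is, the child is a subtree and the function recurses into it; otherwise the child is a leaf node.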

Then draw the structure of the tree

def createPlot(inTree):
    """Render the nested-dict decision tree ``inTree`` and show the figure.

    Prepares a frameless, tick-free axes on figure 1 (stored as
    ``createPlot.ax1`` for the drawing helpers). It then seeds the layout
    state that ``plotTree`` reads and updates, draws the tree starting from
    the top centre, and displays the result.
    """
    figure = plt.figure(1, facecolor='white')
    figure.clf()
    # No frame and no ticks: only the tree itself should be visible.
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    # Layout state: total width in leaf slots and total depth in levels.
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    # Start the x cursor half a slot left of 0 so the first leaf lands in the
    # centre of the first slot; start the y cursor at the top edge.
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    root_position = (0.5, 1.0)
    plotTree(inTree, root_position, '')
    plt.show()
    
def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between a node and its parent.

    The tree-drawing code uses this to label each edge with the branch value
    that leads to the child. The text is centred on the midpoint and rotated
    30 degrees, and is drawn on ``createPlot.ax1``.
    """
    px, py = parentPt[0], parentPt[1]
    cx, cy = cntrPt[0], cntrPt[1]
    midpoint_x = (px - cx) / 2.0 + cx
    midpoint_y = (py - cy) / 2.0 + cy
    createPlot.ax1.text(midpoint_x, midpoint_y, txtString,
                        va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw a nested-dict decision tree onto ``createPlot.ax1``.

    Layout state lives in function attributes that ``createPlot`` sets before
    the first call:

    - ``plotTree.totalW``: total number of leaves.
    - ``plotTree.totalD``: tree depth.
    - ``plotTree.xOff`` / ``plotTree.yOff``: the running cursor, which this
      function advances as it draws.

    Args:
        myTree: nested dict ``{feature: {branch_value: child, ...}}``; a child
            that is a dict is a subtree, anything else is a leaf label.
        parentPt: (x, y) of the parent node that this node connects to.
        nodeTxt: edge label written midway between the parent and this node.
    """
    numLeafs = getNumLeafs(myTree)  # width of this subtree, in leaf slots
    firstStr = list(myTree.keys())[0]  # the feature this node splits on
    # Centre this node horizontally over the leaf slots its subtree will use.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW,
              plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move down one level for this node's children...
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key, child in secondDict.items():
        if isinstance(child, dict):
            plotTree(child, cntrPt, str(key))  # subtree: recurse
        else:
            # Leaf: advance one slot to the right and draw it there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(child, (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # ...then move back up so this node's siblings are drawn at its level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# If you get a dictionary, you know it's a tree, and the first element will be another dict.
Calling `createPlot(inTree)` with a tree dictionary draws the whole tree.

The resulting graph is:

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=324516914&siteId=291194637