Machine Learning in Action: Drawing Tree Diagrams with Matplotlib Annotations

1. First, use the annotate() function from the matplotlib.pyplot module; its annotation feature draws the nodes of the tree.

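import matplotlib.pyplot as plt
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")
# dict() builds the property dictionaries for the node boxes and the arrow; they are
# passed to annotate() later as the bbox and arrowprops arguments
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt,
                            textcoords='axes fraction', va="center", ha="center",
                            bbox=nodeType, arrowprops=arrow_args)
# createPlot.ax1 is a function attribute used as a global variable; the createPlot()
# function defined later assigns the plotting axes to it
# Parameters of plotNode():
#     nodeTxt  : the annotation text, a string
#     centerPt : the position of the annotation text
#     parentPt : the position of the point being annotated
#     nodeType : the box style of the annotation node
# Parameters of annotate():
#     xy and xytext : the position of the annotated point and of the annotation text
#     xycoords and textcoords : the coordinate systems used for xy and xytext
#     va and ha : vertical and horizontal alignment, respectively
#     bbox : property dictionary of the box drawn around the text
#     arrowprops : property dictionary of the arrow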
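
As a minimal, self-contained sketch of the annotate() call used in plotNode(), the snippet below draws one decision node and one leaf node on an axes of its own. The `Agg` backend, the helper name `plot_node`, the node labels and the coordinates are assumptions made for this illustration; they do not come from the original code.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen so that no display is needed
import matplotlib.pyplot as plt

decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")

fig = plt.figure(facecolor="white")
ax = fig.add_subplot(111, frameon=False)

def plot_node(ax, text, center_pt, parent_pt, node_type):
    """Put `text` in a box at center_pt and connect it to parent_pt with an arrow."""
    return ax.annotate(text, xy=parent_pt, xycoords="axes fraction",
                       xytext=center_pt, textcoords="axes fraction",
                       va="center", ha="center",
                       bbox=node_type, arrowprops=arrow_args)

# Hypothetical labels and positions, given as fractions of the axes
decision = plot_node(ax, "a decision node", (0.5, 0.1), (0.1, 0.5), decisionNode)
leaf = plot_node(ax, "a leaf node", (0.8, 0.1), (0.3, 0.8), leafNode)

fig.canvas.draw()  # force a render so that any invalid style name would raise here
```

Passing the axes in explicitly, instead of reading it from createPlot.ax1, is only a convenience that lets the sketch run on its own.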


The detailed usage of annotate() is as follows:

>>> help(pyplot.annotate)

Help on function annotate in module matplotlib.pyplot:

annotate(*args, **kwargs)
    call signature::

      annotate(s, xy, xytext=None, xycoords='data',
               textcoords='data', arrowprops=None, **kwargs)

    Keyword arguments:

# The keyword arguments are described below:
    Annotate the *x*, *y* point *xy* with text *s* at *x*, *y*
    location *xytext*.  (If *xytext* = *None*, defaults to *xy*,
    and if *textcoords* = *None*, defaults to *xycoords*).

# Annotates the point xy with the text s, placed at xytext. (If *xytext* = *None* it defaults to *xy*; if *textcoords* = *None* it defaults to *xycoords*.)

    *arrowprops*, if not *None*, is a dictionary of line properties
    (see :class:`matplotlib.lines.Line2D`) for the arrow that connects
    annotation to the point.

# If *arrowprops* is not *None*, it is a dictionary of line properties for the arrow that connects the annotation to the point.

    If the dictionary has a key *arrowstyle*, a FancyArrowPatch
    instance is created with the given dictionary and is
    drawn. Otherwise, a YAArrow patch instance is created and
    drawn.

# If the dictionary contains the key *arrowstyle*, a FancyArrowPatch instance is created from it and drawn; otherwise a YAArrow instance is created. (I have not worked out how these two kinds of arrow differ.)

    Valid keys for YAArrow are

# The valid keys of the YAArrow property dictionary are described below:

    =========   =========================================================
    Key         Description
    =========   =========================================================
    width       the width of the arrow in points
    frac        the fraction of the arrow length occupied by the head
    headwidth   the width of the base of the arrow head in points
    shrink      oftentimes it is convenient to have the arrowtip
                and base a bit away from the text and point being
                annotated.  If *d* is the distance between the text and
                annotated point, shrink will shorten the arrow so the tip
                and base are shrink percent of the distance *d* away from the
                endpoints.  ie, ``shrink=0.05`` is 5%
    ?           any key for :class:`matplotlib.patches.polygon`
    =========   ========================================================

    Valid keys for FancyArrowPatch are

# The valid keys of the FancyArrowPatch property dictionary are described below:

    ===============  ===================================================
    Key              Description
    ===============  ===================================================
    arrowstyle       the arrow style
    connectionstyle  the connection style
    relpos           default is (0.5, 0.5)
    patchA           default is bounding box of the text
    patchB           default is None
    shrinkA          default is 2 points
    shrinkB          default is 2 points
    mutation_scale   default is text size (in points)
    mutation_aspect  default is 1.
    ?                any key for :class:`matplotlib.patches.PathPatch`
    ===============  ==================================================


    *xycoords* and *textcoords* are strings that indicate the
    coordinates of *xy* and *xytext*.

# *xycoords* and *textcoords* are strings that specify the coordinate systems in which *xy* and *xytext* are given.

    =================   ==============================================
    Property            Description
    =================   ==============================================
    'figure points'     points from the lower left corner of the figure
    'figure pixels'     pixels from the lower left corner of the figure
    'figure fraction'   0,0 is lower left of figure and 1,1 is upper, right
    'axes points'       points from lower left corner of axes
    'axes pixels'       pixels from lower left corner of axes
    'axes fraction'     0,0 is lower left of axes and 1,1 is upper right
    'data'              use the coordinate system of the object being
                        annotated (default)
    'offset points'     Specify an offset (in points) from the *xy* value
    'polar'             you can specify *theta*, *r* for the annotation,
                        even in cartesian plots.  Note that if you
                        are using a polar axes, you do not need
                        to specify polar for the coordinate
                        system since that is the native "data" coordinate
                        system.
    =================   ============================================


    If a 'points' or 'pixels' option is specified, values will be
    added to the bottom-left and if negative, values will be
    subtracted from the top-right.  Eg::

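# As this help text describes it: with a 'points' or 'pixels' option, a positive value is measured from the lower-left corner and a negative value from the upper-right corner, as in the following example.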

      # 10 points to the right of the left border of the axes and
      # 5 points below the top border
      xy=(10,-5), xycoords='axes points'


    You may use an instance of
    :class:`~matplotlib.transforms.Transform` or
    :class:`~matplotlib.artist.Artist`. See
    :ref:`plotting-guide-annotation` for more details.


    The *annotation_clip* attribute controls the visibility of the
    annotation when it goes outside the axes area. If True, the
    annotation will only be drawn when the *xy* is inside the
    axes. If False, the annotation will always be drawn regardless
    of its position.  The default is *None*, which behaves as True
    only if *xycoords* is "data".


    Additional kwargs are Text properties:

# The remaining keyword arguments are Text properties, listed below:

      agg_filter: unknown
      alpha: float (0.0 transparent through 1.0 opaque)
      animated: [True | False]
      axes: an :class:`~matplotlib.axes.Axes` instance
      backgroundcolor: any matplotlib color
      bbox: rectangle prop dict
      clip_box: a :class:`matplotlib.transforms.Bbox` instance
      clip_on: [True | False]
      clip_path: [ (:class:`~matplotlib.path.Path`, :class:`~matplotlib.transforms.Transform`) | :class:`~matplotlib.patches.Patch` | None ]
      color: any matplotlib color
      contains: a callable function
      family or fontfamily or fontname or name: [ FONTNAME | 'serif' | 'sans-serif' | 'cursive' | 'fantasy' | 'monospace' ]
      figure: a :class:`matplotlib.figure.Figure` instance
      fontproperties or font_properties: a :class:`matplotlib.font_manager.FontProperties` instance
      gid: an id string
      horizontalalignment or ha: [ 'center' | 'right' | 'left' ]
      label: any string
      linespacing: float (multiple of font size)
      lod: [True | False]
      multialignment: ['left' | 'right' | 'center' ]
      path_effects: unknown
      picker: [None|float|boolean|callable]
      position: (x,y)
      rasterized: [True | False | None]
      rotation: [ angle in degrees | 'vertical' | 'horizontal' ]
      rotation_mode: unknown
      size or fontsize: [ size in points | 'xx-small' | 'x-small' | 'small' | 'medium' | 'large' | 'x-large' | 'xx-large' ]
      snap: unknown
      stretch or fontstretch: [ a numeric value in range 0-1000 | 'ultra-condensed' | 'extra-condensed' | 'condensed' | 'semi-condensed' | 'normal' | 'semi-expanded' | 'expanded' | 'extra-expanded' | 'ultra-expanded' ]
      style or fontstyle: [ 'normal' | 'italic' | 'oblique']
      text: string or anything printable with '%s' conversion.
      transform: :class:`~matplotlib.transforms.Transform` instance
      url: a url string
      variant or fontvariant: [ 'normal' | 'small-caps' ]
      verticalalignment or va or ma: [ 'center' | 'top' | 'bottom' | 'baseline' ]
      visible: [True | False]
      weight or fontweight: [ a numeric value in range 0-1000 | 'ultralight' | 'light' | 'normal' | 'regular' | 'book' | 'medium' | 'roman' | 'semibold' | 'demibold' | 'demi' | 'bold' | 'heavy' | 'extra bold' | 'black' ]
      x: float
      y: float
      zorder: any number

    .. plot:: mpl_examples/pylab_examples/annotation_demo2.py

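2. To determine the width of the x axis and the height of the y axis, we need the number of leaf nodes and the number of levels in the tree. The two functions below, getNumLeafs() and getTreeDepth(), compute them: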

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))   # root of the tree, i.e. the first splitting feature (myTree.keys()[0] only works in Python 2)
    secondDict = myTree[firstStr]
    for key in secondDict:          # each key is one value of the current root feature
        if isinstance(secondDict[key], dict):           # secondDict[key] is a subtree (its root node)
            numLeafs += getNumLeafs(secondDict[key])    # a dict is a subtree: count its leaves recursively
        else:
            numLeafs += 1                               # anything else is a leaf: add 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict:          # loop over all subtrees of the current root
        if isinstance(secondDict[key], dict):
            thisDepth = 1 + getTreeDepth(secondDict[key])   # a dict is a subtree: compute its depth recursively
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

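3. With plotNode(), getNumLeafs() and getTreeDepth() in place, we can build the plotTree() function that draws the decision tree.

The retrieveTree() function stores some tree data that is used later to test the code: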

def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ] # dictionaries nested in a list; note that the list holds two dictionaries, i.e. two trees
    return listOfTrees[i]
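
As a quick check of the two counting functions against these stored trees, here is a small sketch. It restates retrieveTree(), getNumLeafs() and getTreeDepth() in a Python 3 compatible form so that the snippet runs on its own; the compact `sum`/`max` formulation is my own shorthand and not the original code.

```python
def retrieveTree(i):
    listOfTrees = [
        {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
        {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}},
    ]
    return listOfTrees[i]

def getNumLeafs(myTree):
    """Count the leaves: every non-dict value below the root is one leaf."""
    firstStr = next(iter(myTree))
    return sum(getNumLeafs(child) if isinstance(child, dict) else 1
               for child in myTree[firstStr].values())

def getTreeDepth(myTree):
    """Count the levels of decision nodes along the deepest branch."""
    firstStr = next(iter(myTree))
    return max(1 + getTreeDepth(child) if isinstance(child, dict) else 1
               for child in myTree[firstStr].values())

for i in range(2):
    tree = retrieveTree(i)
    print(i, getNumLeafs(tree), getTreeDepth(tree))
# prints:
# 0 3 2
# 1 4 3
```

The first tree has 3 leaves on 2 levels; the second has 4 leaves on 3 levels because of the extra 'head' node.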

The plotMidText() function places a text label halfway between a parent node and its child:

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)  # createPlot.ax1 is an Axes object; this calls its text() method
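
The midpoint arithmetic can be checked without drawing anything. The helper `mid_point` and the sample coordinates below are hypothetical, introduced only for this illustration.

```python
def mid_point(cntr_pt, parent_pt):
    """Return the point halfway between a child node and its parent."""
    x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
    y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
    return x_mid, y_mid

# Child at (0.25, 0.5), parent at (0.5, 1.0)
print(mid_point((0.25, 0.5), (0.5, 1.0)))  # (0.375, 0.75)
```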

The drawing function plotTree():

The following blog post helped me understand the details of the procedure:

https://www.cnblogs.com/fantasy01/p/4595902.html

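def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)  # number of leaf nodes in the current subtree
    depth = getTreeDepth(myTree)    # depth of the current subtree
    firstStr = next(iter(myTree))   # myTree.keys()[0] only works in Python 2
    # the leaves below the current node determine its x position (this is the core of the layout)
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # plotTree.xOff   : x coordinate of the most recently drawn leaf (not yet the current one; mind the initial value); understanding this is the key
    # plotTree.yOff   : y coordinate of the level currently being drawn
    # plotTree.totalW : number of leaf nodes in the whole tree

    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # the annotation with its arrow (the current node) is now drawn

    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level after drawing the current node

    for key in secondDict:
        if isinstance(secondDict[key], dict):   # not a leaf: draw the subtree recursively
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW  # advance the x coordinate to the next leaf
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)  # draw the leaf and its arrow
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # the loop has finished this level of the tree, so move back up one level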
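
To see how xOff and yOff evolve, the sketch below reproduces only the coordinate bookkeeping of plotTree() and records each node position instead of drawing it. The names `layout`, `count_leaves`, `tree_depth` and the `state` dictionary (standing in for the plotTree.xOff / plotTree.yOff function attributes) are assumptions of this illustration.

```python
def count_leaves(tree):
    root = next(iter(tree))
    return sum(count_leaves(child) if isinstance(child, dict) else 1
               for child in tree[root].values())

def tree_depth(tree):
    root = next(iter(tree))
    return max(1 + tree_depth(child) if isinstance(child, dict) else 1
               for child in tree[root].values())

def layout(tree):
    """Return (label, x, y) for every node, in the order plotTree() would draw them."""
    total_w = float(count_leaves(tree))
    total_d = float(tree_depth(tree))
    # Same initial values as createPlot(): half a leaf slot left of the axes, top level
    state = {"x_off": -0.5 / total_w, "y_off": 1.0}
    positions = []

    def visit(subtree):
        root = next(iter(subtree))
        # Centre the decision node above the leaves of its subtree
        x = state["x_off"] + (1.0 + count_leaves(subtree)) / 2.0 / total_w
        positions.append((root, x, state["y_off"]))
        state["y_off"] -= 1.0 / total_d          # go down one level
        for child in subtree[root].values():
            if isinstance(child, dict):
                visit(child)
            else:
                state["x_off"] += 1.0 / total_w  # next leaf slot
                positions.append((child, state["x_off"], state["y_off"]))
        state["y_off"] += 1.0 / total_d          # go back up one level

    visit(tree)
    return positions

tree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
for label, x, y in layout(tree):
    print(label, round(x, 3), round(y, 3))
# prints:
# no surfacing 0.5 1.0
# no 0.167 0.5
# flippers 0.667 0.5
# no 0.5 0.0
# yes 0.833 0.0
```

With 3 leaves each leaf slot is 1/3 wide, so the leaves land at x = 1/6, 1/2 and 5/6, and every decision node sits above the middle of its own leaves.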


The main function createPlot():

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # create a figure with a white background
    fig.clf()  # clear the figure
    axprops = dict(xticks=[], yticks=[])  # axis property dictionary that removes the ticks; passed to subplot() below
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # on the ** syntax in function calls: https://www.cnblogs.com/empty16/p/6229538.html
    plotTree.totalW = float(getNumLeafs(inTree))   # function attribute used as a global variable
    plotTree.totalD = float(getTreeDepth(inTree))  # function attribute used as a global variable
    plotTree.xOff = -0.5 / plotTree.totalW; plotTree.yOff = 1.0  # initial values of the global variables
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()
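
The `**` unpacking in the subplot() call can be tried in isolation. The sketch below assumes the off-screen `Agg` backend so that it runs without a display; the dictionary keys must be the Axes property names `xticks` and `yticks`.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: render off-screen so that no display is needed
import matplotlib.pyplot as plt

axprops = dict(xticks=[], yticks=[])
fig = plt.figure(1, facecolor="white")
fig.clf()
# Equivalent to plt.subplot(111, frameon=False, xticks=[], yticks=[])
ax = plt.subplot(111, frameon=False, **axprops)

print(len(ax.get_xticks()), len(ax.get_yticks()), ax.get_frame_on())
# prints: 0 0 False
```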




Reposted from blog.csdn.net/carl95271/article/details/80365187