机器学习实战Chp3:决策树--ID3算法

  • 机器学习实战Chp3:决策树–ID3算法

  • 绘制 树形图 模块
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 26 21:55:37 2018

@author: muli

"""

# 可以作为绘制树的模板,直接调用即可
# 树的形式要求如下:
#    A={'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
import matplotlib.pyplot as plt


# 判断节点
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
# 叶子节点
leafNode = dict(boxstyle="round4", fc="0.8")
# 箭头
arrow_args = dict(arrowstyle="<-")


# 绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    # annotate(s='str' ,xy=(x,y) ,xytext=(l1,l2) ,..)
    # s 为注释文本内容 
    # xy 为被注释的坐标点,即箭头的起点
    # xytext 为注释文字的坐标位置
    # xycoords来指定点xy坐标的类型,textcoords指定xytext的类型,bbox给标题增加外框
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',
             xytext=centerPt, textcoords='axes fraction',
             va="center", ha="center", bbox=nodeType, arrowprops=arrow_args )


#def createPlot():
#    # 绘制新图形
#    # plt.figure(1)是新建一个名叫 Figure1的画图窗口
#    fig = plt.figure(1, facecolor='white')
#    # 清空绘制区
#    fig.clf()
#    # subplot(nrows, ncols, plot_number)
#    # plt.subplot作用是把一个绘图区域(可以理解成画布)分成多个小区域,用来绘制多个子图。
#    # nrows和ncols表示将画布分成(nrows*ncols)个小区域,每个小区域可以单独绘制图形;
#    # plot_number表示将图绘制在第plot_number个子区域。
#    createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
#    plotNode('决策节点', (0.5, 0.1), (0.1, 0.5), decisionNode)
#    plotNode('叶子节点', (0.8, 0.1), (0.3, 0.8), leafNode)
#    plt.show()


# 叶子的数目
# 决定 X轴 的长度
# 对叶子节点的数目进行累加
# 可自行测试:
# A={'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # 判断键的值,是否是字典 类型
        if type(secondDict[key]).__name__=='dict':
            numLeafs += getNumLeafs(secondDict[key])
        # 如果不是字典类型,则为叶子节点
        else:   
            numLeafs +=1
    return numLeafs


# 树的高度
# 决定 Y轴 的深度
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # 判断键的值,是否是字典 类型
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   
            thisDepth = 1
        # 在Y轴上,比较的是 最大值
        if thisDepth > maxDepth: 
            maxDepth = thisDepth
    return maxDepth


# 创建 树的信息
# 主要用于测试
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)


# 绘制 树
def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat was split on
    numLeafs = getNumLeafs(myTree)  #this determines the x width of this tree
    depth = getTreeDepth(myTree)
    firstStr = myTree.keys()[0]     #the text label for this node should be this
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':#test to see if the nodes are dictonaires, if not they are leaf nodes   
            plotTree(secondDict[key],cntrPt,str(key))        #recursion
        else:   #it's a leaf node print the leaf node
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
#if you do get a dictonary you know it's a tree, and the first element will be another dict


# 绘制树形图
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)    #no ticks
    #createPlot.ax1 = plt.subplot(111, frameon=False) #ticks for demo puropses 
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
    plotTree(inTree, (0.5,1.0), '')
    plt.show()


# 测试模块
if __name__ == "__main__" :
#    createPlot()
    myTree=retrieveTree(0)
    print(myTree)
    num_Leafs=getNumLeafs(myTree)
    num_Depth=getTreeDepth(myTree)
    print('叶子节点的数目:'+str(num_Leafs))
    print('树的深度:'+str(num_Depth))
    print("------------------------------")
    createPlot(myTree)
    print("------------------------------")
    myTree['no surfacing'][3]='maybe'
    print(myTree)
    createPlot(myTree)


  • 构建决策树及测试模块
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 26 13:44:00 2018

@author: muli
"""

from math import log
import operator
import pickle
import treePlotter


# 创建数据集
def createDataSet():
    dataSet = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    # 依据上面的数据集,依次定义特征
    labels = ['no surfacing','flippers']
    #change to discrete values
    return dataSet, labels


# 计算香浓熵
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    # 创建字典
    labelCounts = {}
    for featVec in dataSet:
        # 得到每个类别的标签
        currentLabel = featVec[-1]
        # 判断字典中是否有 这种记录
        # 没有的话,将记录置为 0
        if currentLabel not in labelCounts.keys(): 
            labelCounts[currentLabel] = 0
        # 类别标签值加 1
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        # 求每个可能种类的概率
        prob = float(labelCounts[key])/numEntries
        # 香浓公式的计算
        shannonEnt -= prob * log(prob,2)
    return shannonEnt


# 划分数据集
# axis:表示第axis维的特征
# value:表示该维特征的取值
def splitDataSet(dataSet, axis, value):
    retDataSet = []
    # featVec 为矩阵的每一行数据
    for featVec in dataSet:
        # 对行中特定的列,进行比对
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]    
            # 截取其余的特征
            reducedFeatVec.extend(featVec[axis+1:])
            # 上面两条语句加起来,则去掉第axis维特征
            # 换句话说,即考虑 剩余的特征
            retDataSet.append(reducedFeatVec)
    return retDataSet


# 选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    # 判定数据集有多少个特征
    numFeatures = len(dataSet[0]) - 1      #the last column is used for the labels
    # 计算香浓的信息值期望
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0; bestFeature = -1
    # 对每个特征进行遍历
    for i in range(numFeatures):        #iterate over all the features
        # 提取每一行的第i个特征
        featList = [example[i] for example in dataSet]#create a list of all the examples of this feature
        # 得到的是,每个特征的属性值,有多个
        # 去重/无序的集合
        uniqueVals = set(featList)       #get a set of unique values
        newEntropy = 0.0
        # 对每个属性值进行计算 信息值的期望
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            # 考虑到权重计算,可参考:周志华《机器学习》P75
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)  
        # 计算信息增益
        infoGain = baseEntropy - newEntropy     #calculate the info gain; ie reduction in entropy
        if (infoGain > bestInfoGain):       #compare this to the best gain so far
            bestInfoGain = infoGain         #if better than current best, set to best
            bestFeature = i
    return bestFeature                      #returns an integer


# 多数表决法
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): 
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]


# 创建树的代码
def createTree(dataSet,labels):
    # 获得 类别标签
    classList = [example[-1] for example in dataSet]
    # List中count()方法:用于统计某个元素在列表中出现的次数
    if classList.count(classList[0]) == len(classList): 
        return classList[0]
    # 剩余数据集中,
    if len(dataSet[0]) == 1: 
        return majorityCnt(classList)
    # 经过以上两种情况的判断,若不满足要求,则 选择最好的特征 进行数据集划分
    bestFeat = chooseBestFeatureToSplit(dataSet)
    # 获得特征标签
    bestFeatLabel = labels[bestFeat]
    # myTree:存储树的所有信息
    # 当前存储最好的特征标签在字典中
    myTree = {bestFeatLabel:{}}
    # 删除这个最好的 特征,进行下一步 的 数据集划分
    del(labels[bestFeat])
    # 获得最好特征对应列的所有取值
    featValues = [example[bestFeat] for example in dataSet]
    # 去重/无序的集合
    uniqueVals = set(featValues)
    # 遍历该特征的属性值,有多个,对应多个分支
    for value in uniqueVals:
        # 前面del()删除了一个特征,即 取剩余的所有标签
        subLabels = labels[:]       #copy all of labels, so trees don't mess up existing labels
        # 递归调用 创建树 的函数
        # 本函数的关键地方
        # 前面两个if语句的返回值 是标签--'yes'和'no'
        # 如果是叶子节点,返回值为'yes'或者'no'
        # 如果非叶子结点,选择下一个最好的特征,进一步做递归
        # 直到全部为叶子节点,将值返回
        # 最终函数返回值为 myTree 
        myTree[bestFeatLabel][value] = createTree(
                splitDataSet(dataSet, bestFeat, value),subLabels)
    return myTree   


# 使用决策树分类函数
def classify(inputTree,featLabels,testVec):
    # 取出第一个特征
    firstStr = inputTree.keys()[0]
    # 取出第一个特征对应的值
    secondDict = inputTree[firstStr]
    # 特征标签的索引
    featIndex = featLabels.index(firstStr)
    # 取得测试数据对应的特征属性值--
    key = testVec[featIndex]
    # 由 键 获取值
    valueOfFeat = secondDict[key]
    # 如果是 字典类型,递归调用;否则,返回该值
    if isinstance(valueOfFeat, dict): 
        classLabel = classify(valueOfFeat, featLabels, testVec)
    else: 
        classLabel = valueOfFeat
    return classLabel


# 创建 树的信息
# 主要用于测试
def retrieveTree(i):
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}
                  ]
    return listOfTrees[i]


# 使用 pickle 模块 存储决策树
def storeTree(inputTree,filename):
    fw = open(filename,'w')
    pickle.dump(inputTree,fw)
    fw.close()


# 使用 pickle 模块 读取决策树
def grabTree(filename):
    fr = open(filename)
    return pickle.load(fr)


#fr=open('lenses.txt')
#lenses=[inst.strip().split('\t') for inst in fr.readlines()]
#lensesLabels=['age','prescript','astigmatic','tearRate']
#lensesTree=createTree(lenses,lensesLabels)
#print(lensesTree)
##    createPlot
##    NameError: name 'createPlot' is not defined
#createPlot(lensesTree)



# 测试模块
if __name__ == "__main__" :
    myDat,labels=createDataSet()
    print(myDat)
    print(labels)
    print("--------------------------")
#    pf=calcShannonEnt(myDat)
#    print(pf)
#    print("--------------------------")
#    myDat[0][-1]='maybe'
#    pf=calcShannonEnt(myDat)
#    print(pf)
##    print("---------------------------")
#    re=splitDataSet(myDat,0,1)
#    print(re)
#    re=splitDataSet(myDat,0,0)
#    print(re)
#    x=chooseBestFeatureToSplit(myDat)
#    print("最好的特征是第:"+str(x)+" 个")
#    print("---------------------------")
#    print(labels)
#    myTree=createTree(myDat,labels)
#    print(myTree)
#    myTree=retrieveTree(0)
#    print(myTree)
#    print("---------------------------")
#    print("测试如下:")
#    res=classify(myTree,labels,[1,0])
#    print(res)
#    res=classify(myTree,labels,[1,1])
#    print(res)
#    print("---------------------------")
#    myTree=createTree(myDat,labels)
#    print("正在存储树模型!")
#    storeTree(myTree,'classifierStorage.txt')
#    print("正在读取树模型!")
#    get_Tree=grabTree('classifierStorage.txt')
#    print(get_Tree)
    print("************************************")
    fr=open('lenses.txt')
    lenses=[inst.strip().split('\t') for inst in fr.readlines()]
    lensesLabels=['age','prescript','astigmatic','tearRate']
    lensesTree=createTree(lenses,lensesLabels)
    print(lensesTree)
#    createPlot
#    NameError: name 'createPlot' is not defined
    createPlot(lensesTree)

猜你喜欢

转载自blog.csdn.net/mr_muli/article/details/81135767