Machine Learning for Beginners --- Chapter 4: Decision Trees (2)

Here is my Python implementation of the ID3 algorithm~~~

Written with reference to Machine Learning in Action.
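
A quick refresher on what the code below computes: for a dataset D whose samples fall into K classes with proportions p_1, ..., p_K, the Shannon entropy is H(D) = -Σ_k p_k·log2(p_k), and the information gain of a feature A is Gain(D, A) = H(D) - Σ_v (|D_v|/|D|)·H(D_v), where D_v is the subset of D taking value v on A. At every node, ID3 splits on the feature with the largest gain.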

#-*- coding: UTF-8 -*-
from math import log
import operator
from matplotlib.font_manager import FontProperties
import matplotlib.pyplot as plt
import copy
#Create the test dataset
def createDataSet():
    dataSet=[['young', 0, 0, 0, 'no'],         #dataset: 'no' means the loan is denied, 'yes' means it is granted
            ['young', 0, 0, 1, 'no'],
            ['young', 1, 0, 1, 'yes'],
            ['young', 1, 1, 0, 'yes'],
            ['young', 0, 0, 0, 'no'],
            ['middle', 0, 0, 0, 'no'],
            ['middle', 0, 0, 1, 'no'],
            ['middle', 1, 1, 1, 'yes'],
            ['middle', 0, 1, 2, 'yes'],
            ['middle', 0, 1, 2, 'yes'],
            ['old', 0, 1, 2, 'yes'],
            ['old', 0, 1, 1, 'yes'],
            ['old', 1, 0, 1, 'yes'],
            ['old', 1, 0, 2, 'yes'],
            ['old', 0, 0, 0, 'no']]
    labels=['age','has job','has house','credit']#credit: 0, 1, 2 mean fair, good, excellent
    
    return dataSet,labels

#Compute the Shannon entropy of the dataset
def calShannonEnt(dataSet):
    labelCounts={}
    for item in dataSet:
        label=item[-1]
        if(label not in labelCounts.keys()):
            labelCounts[label]=1
        else:
            labelCounts[label]+=1
    length=len(dataSet)
    shannonEnt=0.0
    for i in labelCounts:
        p=labelCounts[i]/length
        shannonEnt-=p*log(p,2)
    return shannonEnt
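
#Quick sanity check (assumed usage): the loan dataset above has 9 'yes' and
#6 'no' samples, so its entropy should be
#    -(9/15)*log2(9/15) - (6/15)*log2(6/15) ≈ 0.971
#which is what calShannonEnt(createDataSet()[0]) should print.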

###Split the dataset on a given feature
def splitDataSet(dataSet,index,value):#index: feature column index; value: the feature value to keep
    returnData=[]
    for item in dataSet:
        if(item[index]==value):
            item2=item[:index]
            item2.extend(item[index+1:])
            returnData.append(item2)
    return returnData
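
#Example: splitDataSet(dataSet, 0, 'young') should keep the five 'young' rows
#and drop the age column, so its first element would be [0, 0, 0, 'no'].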

###Choose the best feature to split on
def chooseBestFeatureToSplit(dataSet):
    featureNum=len(dataSet[0])-1
    baseEnt=calShannonEnt(dataSet)
    maxGain=0.0
    bestFeature=-1
    for i in range(featureNum):
        #first collect the distinct values of feature i
        featureValues=[]
        currentEnt=0.0
        for item in dataSet:
            featureValues.append(item[i])
        featureValues=set(featureValues)
        
        #split on each value and accumulate the conditional entropy
        for value in featureValues:
            splitData=splitDataSet(dataSet,i,value)
            p=len(splitData)/len(dataSet)
            ent=calShannonEnt(splitData)
            currentEnt+=p*ent
        currentGain=baseEnt-currentEnt
        print("第%d个特征的增益为%.3f" % (i, currentGain))
        if(maxGain<currentGain):
            maxGain=currentGain
            bestFeature=i
    return bestFeature
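
#For the dataset above this should report gains of roughly 0.083 (age),
#0.324 (has job), 0.420 (has house) and 0.363 (credit), so the returned
#best feature index is 2 ('has house').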
        
###Return the most frequent element in classList
def majorityCnt(classList):
    classCount={}
    for item in classList:
        if item not in classCount.keys():
            classCount[item]=1
        else:
            classCount[item]+=1
    sortedClassCount=sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    print(sortedClassCount)
    return sortedClassCount[0][0]
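
#Example: majorityCnt(['yes', 'no', 'yes']) should return 'yes'.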

#Build the decision tree
"""
Function: build the decision tree

Parameters:
    dataSet - training dataset
    labels - feature label names
    featLabels - stores the best feature label chosen at each split
In the tree-building code you can see a featLabels parameter. What is it for?
It records the decision nodes in the order they are chosen, so that when we
use the tree for prediction we simply supply the attribute values of those
nodes in order.
Returns:
    myTree - the decision tree

"""
def createTree(dataSet,labels,featLabels):
    classList=[example[-1] for example in dataSet]
    if(classList.count(classList[0])==len(classList)):
        return classList[0]
    if(len(dataSet[0])==1):
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    featLabels.append(bestFeatLabel)
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])
    #collect every value the best feature takes in the training set
    featValues=[example[bestFeat] for example in dataSet]
    featValues=set(featValues)
    
    for value in featValues:
        subLabels=labels[:]    #copy, so sibling branches do not see each other's deletions
        myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels,featLabels)
    return myTree
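
#On this dataset the recursion should produce
#    {'has house': {0: {'has job': {0: 'no', 1: 'yes'}}, 1: 'yes'}}
#everyone with a house gets a loan; among the rest, only those with a job do.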

"""
函数说明:获取决策树叶子结点的数目

Parameters:
    myTree - 决策树
Returns:
    numLeafs - 决策树的叶子结点的数目

"""
def getNumLeafs(myTree):
    numLeafs=0
    firstStr=next(iter(myTree))
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if(type(secondDict[key]).__name__=='dict'):
            numLeafs+=getNumLeafs(secondDict[key])
        else:
            numLeafs+=1
    return numLeafs

"""
函数说明:获取决策树的层数

Parameters:
    myTree - 决策树
Returns:
    maxDepth - 决策树的层数

"""
def getTreeDepth(myTree):
    maxDepth=0
    firstStr=next(iter(myTree))
    secondDict=myTree[firstStr]
    for key in secondDict.keys():
        if(type(secondDict[key]).__name__=='dict'):
            thisDepth=1+getTreeDepth(secondDict[key])
        else:
            thisDepth=1
        if(thisDepth>maxDepth):
            maxDepth=thisDepth
    return maxDepth
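
#For the tree built above, getNumLeafs should return 3 and getTreeDepth 2.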

"""
函数说明:绘制结点

Parameters:
    nodeTxt - 结点名
    centerPt - 文本位置
    parentPt - 标注的箭头位置
    nodeType - 结点格式
Returns:
    无

"""
def plotNode(nodeTxt,centerPt,parentPt,nodeType):
    arrow_args = dict(arrowstyle="<-")                                            #定义箭头格式
    font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)        #设置中文字体
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',    #绘制结点
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)

"""
函数说明:标注有向边属性值

Parameters:
    cntrPt、parentPt - 用于计算标注位置
    txtString - 标注的内容
Returns:
    无

"""

def plotMidText(cntrPt,parentPt,txtString):
    #计算标注位置
    xMid=(parentPt[0]-cntrPt[0])/2+cntrPt[0]
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

    
"""
函数说明:绘制决策树

Parameters:
    myTree - 决策树(字典)
    parentPt - 标注的内容
    nodeTxt - 结点名
Returns:
    无

"""

def plotTree(myTree, parentPt, nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")                                        #style for decision nodes
    leafNode = dict(boxstyle="round4", fc="0.8")                                            #style for leaf nodes
    numLeafs = getNumLeafs(myTree)                                                          #number of leaves determines the width of the tree
    depth = getTreeDepth(myTree)                                                            #depth of the tree
    firstStr = next(iter(myTree))                                                            #root key of the current subtree
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)    #center position of this node
    plotMidText(cntrPt, parentPt, nodeTxt)                                                    #annotate the edge from the parent
    plotNode(firstStr, cntrPt, parentPt, decisionNode)                                        #draw the decision node
    secondDict = myTree[firstStr]                                                            #children of this node
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD                                        #move one level down
    for key in secondDict.keys():                               
        if type(secondDict[key]).__name__=='dict':                          #a dict means an internal node, not a leaf
            plotTree(secondDict[key],cntrPt,str(key))                       #recurse into the subtree
        else:                                                               #a leaf: draw it and label the incoming edge
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD                     #move back up one level
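
#Layout note: the unit square is divided into totalW leaf slots horizontally
#and totalD levels vertically. xOff starts half a slot left of the first
#leaf, so the i-th leaf (counting from 0) lands at x = (i + 0.5)/totalW, and
#an internal node is centered above the leaves of its own subtree.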

    
"""
函数说明:创建绘制面板

Parameters:
    inTree - 决策树(字典)
Returns:
    无

"""
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')                                                    #创建fig
    fig.clf()                                                                                #清空fig
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)                                #去掉x、y轴
    plotTree.totalW = float(getNumLeafs(inTree))                                            #获取决策树叶结点数目
    plotTree.totalD = float(getTreeDepth(inTree))                                            #获取决策树层数
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;                                #x偏移
    plotTree(inTree, (0.5,1.0), '')                                                            #绘制决策树
    plt.show()  

    
"""
函数说明:使用决策树分类

Parameters:
    inputTree - 已经生成的决策树
    featLabels - 特征标签
    testVec - 测试数据列表
Returns:
    classLabel - 分类结果

"""
def classify(inputTree,featLabels,testVec):
    firstStr=next(iter(inputTree))
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if(testVec[featIndex]==key):
            if(type(secondDict[key]).__name__=='dict'):
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:
                classLabel=secondDict[key]
    return classLabel
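
#Example walk for testVec = [0, 1, 0, 1]: the root 'has house' has index 2
#in featLabels, testVec[2] == 0 leads into the 'has job' subtree, and
#testVec[1] == 1 yields 'yes'. Note that classLabel would stay unbound if
#testVec contained a value the tree never saw during training.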

if __name__=='__main__':
    dataSet,labels=createDataSet()
    labelTemp=copy.copy(labels)                                #keep a full copy: createTree deletes entries from labels
    print(dataSet)
    print(calShannonEnt(dataSet))
    print("Best feature index: "+str(chooseBestFeatureToSplit(dataSet)))
    featLabels=[]
    myTree=createTree(dataSet,labels,featLabels)
    print(myTree)
    createPlot(myTree)  
    testVec = [0,1,0,1]                                        #test sample (only the features present in the tree are consulted)
    result = classify(myTree, labelTemp, testVec)
    if result == 'yes':
        print('grant the loan')
    if result == 'no':
        print('deny the loan')


Reposted from blog.csdn.net/hx14301009/article/details/79727945