[Python and Machine Learning Primer 2] Decision Trees (3): Using a Decision Tree to Predict Contact Lens Type

Copyright notice: This is an original post by the author; do not repost without permission. https://blog.csdn.net/maotianyi941005/article/details/82384428

Reference blog: "决策树实战篇之为自己配个隐形眼镜" by Jack-Cui (much of the content below is adapted from that post)

Reference book: Machine Learning in Action, Section 3.4

Decision-tree fundamentals are covered in the previous two posts in this series.

Abstract: This post uses the contact-lens prediction problem to show how to build and visualize a decision tree, and introduces code for building one with sklearn.

Contents

1 Data processing

2 Complete code

3 Visualization with Matplotlib

4 Building a decision tree with sklearn


1 Data processing

The contact-lens dataset is a well-known dataset that records observations of patients' eye conditions together with the lens type a doctor recommended. There are three lens types: hard, soft, and no lenses (the patient should not wear contact lenses). Given this dataset, we build a decision tree to predict a patient's lens type (three classes: hard / soft / no lenses).

The lenses.txt file contains 24 samples with 5 tab-separated columns; the 5th column is the lens type, i.e., the class we want to predict.

The data labels are [age, prescript, astigmatic, tearRate, class],

                 i.e. [age, prescription type, whether astigmatic, tear production rate, the final class label]
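The first few rows of lenses.txt look roughly like this (a sample for illustration; the feature values match the tree output in section 2, and the columns are tab-separated):

young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard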

'''Create the dataset'''
def createDataSet():
    fr = open('lenses.txt')
    dataSet = [rl.strip().split('\t') for rl in fr.readlines()]
    print(dataSet)
    labels = ['age', 'prescript', 'astigmatic', 'tearRate']      # feature names
    return dataSet, labels                                       # return the dataset and the feature names

2 Complete code

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import operator
from math import log

'''Create the dataset'''
def createDataSet():
    fr = open('lenses.txt')
    dataSet = [rl.strip().split('\t') for rl in fr.readlines()]
    labels = ['age', 'prescript', 'astigmatic', 'tearRate']      # feature names
    return dataSet, labels                                       # return the dataset and the feature names

'''Empirical (Shannon) entropy of the dataset'''
def calShannonEnt(dataset):
    m = len(dataset)
    labelCount = {}
    # count the occurrences of each class label
    for data in dataset:
        currentLabel = data[-1]
        if currentLabel not in labelCount.keys():
            labelCount[currentLabel] = 0
        labelCount[currentLabel] += 1
    # sum -p * log2(p) over the class frequencies
    entropy = 0
    for label in labelCount:
        p = float(labelCount[label]) / m
        entropy -= p * log(p, 2)
    return entropy
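
# A quick sanity check (my arithmetic, not from the original post): the full
# lens dataset has 4 "hard", 5 "soft" and 15 "no lenses" samples, so
#   H = -(4/24)*log2(4/24) - (5/24)*log2(5/24) - (15/24)*log2(15/24) ≈ 1.326
# and calShannonEnt(dataSet) should return roughly 1.326.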

'''Extract the sub-dataset where feature axis takes the given value (the feature column itself is dropped)'''
def splitdataset(dataset, axis, value):
    subSet = []
    for data in dataset:
        if data[axis] == value:
            data_x = data[:axis]
            data_x.extend(data[axis+1:])
            subSet.append(data_x)
    return subSet
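
# Example: with dataset [['young', 'myope', 'soft'], ['pre', 'myope', 'soft']],
# splitdataset(dataset, 0, 'young') keeps only the first row and drops column 0,
# returning [['myope', 'soft']].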

'''Iterate over all features and return the one with the highest information gain'''
def chooseBestFeatureToSplit(dataSet):
    feature_num = len(dataSet[0]) - 1
    origin_ent = calShannonEnt(dataSet)
    best_infogain = 0.0
    best_feature = -1
    for i in range(feature_num):
        fi_all = set([data[i] for data in dataSet])     # all distinct values of feature i
        subset_Ent = 0
        # weighted entropy over every possible value of feature i
        for value in fi_all:
            subset = splitdataset(dataSet, i, value)
            p = float(len(subset)) / len(dataSet)
            subset_Ent += p * calShannonEnt(subset)
        # information gain = entropy before the split - weighted entropy after
        infoGain = origin_ent - subset_Ent
        #print "information gain of feature %d: %f" % (i, infoGain)
        if infoGain > best_infogain:
            best_feature = i
            best_infogain = infoGain
    return best_feature
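
# On the full lens dataset the largest information gain belongs to tearRate
# (feature index 3), which is why it ends up as the root of the tree below.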

'''Count the class labels and return the most frequent one'''
def majorityCnt(classList):
    classCount = {}
    for class_ in classList:
        if class_ not in classCount.keys():
            classCount[class_] = 0
        classCount[class_] += 1
    # items() works in both Python 2 and 3 (iteritems() is Python 2 only)
    classSort = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return classSort[0][0]
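
# Example: majorityCnt(['soft', 'no lenses', 'no lenses']) returns 'no lenses'.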

'''Recursively build the tree'''
def createTree(dataSet, labels, feaLabels):
    # all class labels in this dataset
    classList = [example[-1] for example in dataSet]
    # two stopping conditions:
    # 1. every sample belongs to the same class
    if len(classList) == classList.count(classList[0]):
        return classList[0]
    # 2. no features left to split on -> return the majority class
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    # otherwise keep splitting
    best_feature = chooseBestFeatureToSplit(dataSet)    # index of the best feature
    best_feaLabel = labels[best_feature]
    feaLabels.append(best_feaLabel)       # record the chosen feature
    del(labels[best_feature])             # ID3 consumes a feature once it is used
    feaValue = set([example[best_feature] for example in dataSet])  # all values of the best feature
    # the key of the subtree's root is the chosen feature name;
    # each value maps to the subtree built recursively from the matching subset
    deci_tree = {best_feaLabel: {}}
    for value in feaValue:
        subLabel = labels[:]              # copy, since the recursion modifies the list
        subset = splitdataset(dataSet, best_feature, value)
        deci_tree[best_feaLabel][value] = createTree(subset, subLabel, feaLabels)
    return deci_tree



if __name__ == '__main__':

    dataSet, labels = createDataSet()
    feaLabels = []
    mytree = createTree(dataSet, labels, feaLabels)
    print(json.dumps(mytree, ensure_ascii=False))

The resulting tree:

{"tearRate": {"reduced": "no lenses", "normal": {"astigmatic": {"yes": {"prescript": {"hyper": {"age": {"pre": "no lenses", "presbyopic": "no lenses", "young": "hard"}}, "myope": "hard"}}, "no": {"age": {"pre": "soft", "presbyopic": {"prescript": {"hyper": "soft", "myope": "no lenses"}}, "young": "soft"}}}}}}

3 Visualization with Matplotlib

The dict form of the tree above is hard to read, so let's visualize it with Matplotlib.

Environment: macOS 10.12.3, Python 2.7

Install the module. For Python 2:

pip install matplotlib

For Python 3:

pip3 install matplotlib

Import the modules in the code:

import matplotlib

import matplotlib.pyplot as plt

The functions we need:

  • getNumLeafs: count the leaf nodes of the tree
  • getTreeDepth: compute the number of levels of the tree
  • plotNode2: draw a node
  • plotMidText: label an edge with the feature value it represents
  • plotTree: draw the tree
  • createPlot: set up the drawing canvas
def getNumLeafs(myTree):
    numLeafs = 0                                              # leaf counter
    firstStr = next(iter(myTree))                             # in Python 3, myTree.keys() returns a dict_keys view rather than a list, so myTree.keys()[0] no longer works; next(iter(myTree)) (or list(myTree.keys())[0]) works in both versions
    secondDict = myTree[firstStr]                             # the subtree dict
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':            # a dict means an internal node; anything else is a leaf
            numLeafs += getNumLeafs(secondDict[key])
        else:   numLeafs +=1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0                                              # tree depth
    firstStr = next(iter(myTree))                             # works in both Python 2 and 3 (see getNumLeafs)
    secondDict = myTree[firstStr]                             # the subtree dict
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':            # a dict means an internal node; anything else is a leaf
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:   thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth         # track the deepest branch
    return maxDepth
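
# For the tree printed in section 2, getNumLeafs returns 9 and getTreeDepth
# returns 4, so the canvas is divided into 9 leaf slots horizontally and
# 4 levels vertically.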

def plotNode2(nodeTxt, centerPt, parentPt, nodeType):
    arrow_args = dict(arrowstyle="<-")                                            # arrow style
    #font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)         # Chinese font (Windows only; not needed here)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt,  xycoords='axes fraction',      # draw the node
        xytext=centerPt, textcoords='axes fraction',
        va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]                                # midpoint of the edge, where the label goes
    yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")                                      # style for internal (decision) nodes
    leafNode = dict(boxstyle="round4", fc="0.8")                                            # style for leaf nodes
    numLeafs = getNumLeafs(myTree)                                                          # the number of leaves determines the width of the tree
    depth = getTreeDepth(myTree)                                                            # number of levels
    firstStr = next(iter(myTree))                                                           # feature name at this node
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)   # center of this node
    plotMidText(cntrPt, parentPt, nodeTxt)                                                  # label the incoming edge with the feature value
    plotNode2(firstStr, cntrPt, parentPt, decisionNode)                                     # draw the node
    secondDict = myTree[firstStr]                                                           # recurse into the children
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD                                     # move one level down
    for key in secondDict.keys():
        if type(secondDict[key]).__name__=='dict':                                          # internal node: recurse
            plotTree(secondDict[key],cntrPt,str(key))
        else:                                                                               # leaf: draw it and label the edge
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode2(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD                                     # move back up before returning

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')                                                  # create the figure
    fig.clf()                                                                               # clear it
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)                             # hide the x and y axes
    plotTree.totalW = float(getNumLeafs(inTree))                                            # total number of leaves
    plotTree.totalD = float(getTreeDepth(inTree))                                           # total depth
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;                              # initial x/y offsets
    plotTree(inTree, (0.5,1.0), '')                                                         # draw the tree
    plt.show()                                                                              # show the plot


if __name__ == '__main__':

    dataSet, labels = createDataSet()
    feaLabels = []
    mytree = createTree(dataSet,labels,feaLabels)
    # print json.dumps(mytree,ensure_ascii=False)
    createPlot(mytree)

4 Building a decision tree with sklearn
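
The original post ends here without the promised sklearn code, so below is a minimal sketch of how it could look; the encoding scheme and the example patient are my assumptions. sklearn's DecisionTreeClassifier only accepts numeric features, so each string column is first mapped to integers with LabelEncoder. Note also that sklearn implements CART (binary splits) rather than ID3, so even with criterion='entropy' the resulting tree will not exactly match the dict built above.

from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier

with open('lenses.txt') as fr:
    rows = [line.strip().split('\t') for line in fr]

# encode each string column as integers, keeping one encoder per column
columns = list(zip(*rows))                          # transpose: one tuple per column
encoders = [LabelEncoder().fit(col) for col in columns]
encoded = [enc.transform(col) for enc, col in zip(encoders, columns)]
X = list(zip(*encoded[:-1]))                        # first four columns are the features
y = encoded[-1]                                     # last column is the class

clf = DecisionTreeClassifier(criterion='entropy')   # entropy criterion, as in ID3
clf.fit(X, y)

# predict for one hypothetical patient: young / myope / not astigmatic / normal tears
sample = ['young', 'myope', 'no', 'normal']
sample_enc = [encoders[i].transform([v])[0] for i, v in enumerate(sample)]
print(encoders[-1].inverse_transform(clf.predict([sample_enc])))   # expect 'soft'

One caveat on the design: LabelEncoder imposes an arbitrary ordering on the categories, which is acceptable for a small demo like this; for a real application one-hot encoding (e.g. OneHotEncoder) would be the safer choice for nominal features.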
