Python创建决策树—解决隐形眼镜选择问题

现在我们碰到这样一个问题,一个人去医院想配一副隐形眼镜。我们需要通过问他4个问题,决定他需要带眼镜的类型。那么如何解决这个问题呢?我们决定用决策树。首先我们去下载一个隐形眼镜数据集,数据来源于UCI数据库。下载了lenses.data文件,如下:

1  1  1  1  1  3
2  1  1  1  2  2
3  1  1  2  1  3
4  1  1  2  2  1
5  1  2  1  1  3
6  1  2  1  2  2
7  1  2  2  1  3
8  1  2  2  2  1
9  2  1  1  1  3
10  2  1  1  2  2
11  2  1  2  1  3
12  2  1  2  2  1
13  2  2  1  1  3
14  2  2  1  2  2
15  2  2  2  1  3
16  2  2  2  2  3
17  3  1  1  1  3
18  3  1  1  2  3
19  3  1  2  1  3
20  3  1  2  2  1
21  3  2  1  1  3
22  3  2  1  2  2
23  3  2  2  1  3
24  3  2  2  2  3

我们可以看到,第一列的1到24,对应数据的ID

第二列的1到3,分别对应病人的年龄(age of patient),分别是青年(young),中年(pre-presbyopic),老年(presbyopic)

第三列的1和2,分别对应近视情况(spectacle prescription),近视(myope),远视(hypermetrope)

第四列的1和2,分别对应眼睛是否散光(astigmatic),不散光(no),散光(yes)

第五列的1和2,分别对应分泌眼泪的频率(tear production rate),很少(reduce),普通(normal)

第六列的1到3,则是最终根据以上数据得到的分类,分别是硬性的隐形眼镜(hard),软性的隐形眼镜(soft),不需要带眼镜(no lenses)

数据我们获取到了,那么我们写一个函数去打开文件设定好数据集,以下是代码:

from numpy import *
import operator
from math import log

def createLensesDataSet():#创建隐形眼镜数据集
    fr = open('lenses.data')
    allLinesArr = fr.readlines()
    linesNum = len(allLinesArr)
    returnMat = zeros((linesNum, 4))
    statusLabels = ['age of the patient', 'spectacle prescription', 'astigmatic', 'tear production rate']
    classLabelVector = []
    classLabels = ['hard', 'soft', 'no lenses']

    index = 0
    for line in allLinesArr:
        line = line.strip()
        lineList = line.split('  ')
        returnMat[index, :] = lineList[1:5]
        classIndex = int(lineList[5]) - 1
        classLabelVector.append(classLabels[classIndex])  # 索引-1代表列表最后一个元素
        index += 1

    return ndarray.tolist(returnMat), statusLabels, classLabelVector

def createLensesAttributeInfo():
    parentAgeList = ['young', 'pre', 'presbyopic']
    spectacleList = ['myope', 'hyper']
    astigmaticList = ['no', 'yes']
    tearRateList = ['reduced', 'normal']
    return parentAgeList, spectacleList, astigmaticList, tearRateList

那么接下来我们应该设定决策树的分支,如何确定以上哪一个特征是第一个分支呢,我们要提到一个概念,香农熵(Shannon entropy)。熵这个概念代表信息的不确定性的大小,在划分数据集中经常会运用到。

它的公式是:

那么我们先写一个计算香农熵的函数:

def calcShannonEnt(dataSet):#计算香农熵
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel] += 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

经过计算,我们可以得到我们当前使用的数据集,熵为:1.32608752536

然后,我们写一个划分数据集的函数,可以根据数据集,特征索引和特征值来划分数据集:

def splitDataSet(dataSet, axis, value):#按照特征值划分数据集,参数为数据集,特征索引,特征值
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

说到取最佳特征值,我们就要提到一个概念信息增益(information divergence)

他的公式是:

即将单独一个特征值提取出来,计算该特征值每个分支划分出数据集的熵的求和,然后用总数据集的熵减去它

计算四个特征值的信息增益我们得到以下数据:

0:0.0393965036461
1:0.0395108354236
2:0.377005230011
3:0.548794940695

以下是计算信息增益的代码:

def chooseBestFeatureToSplit(dataSet):#选择最佳分割特征值
    numFeatures =  len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataSet]
        uniqueVals = set(featList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        print(str(i)+':'+str(infoGain))
        if (infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

通过计算我们可以得出特征值的优先级,tear production rate>astigmatic>spectacle prescription>age of patient

接下来,有了以上的计算函数,我们就可以开始创建决策树了,创建决策树,我们使用字典类型去存储,用键代表分支节点,值代表下一个节点或者叶子节点,代码如下:

def createTree(dataSet, labels):#创建决策树
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):
        print(classList[0])
        return classList[0]
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel:{}}
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    for value in uniqueVals:
        subLabels = labels[:]
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def majorityCnt(classList):#对于单个特征值的列表,按出现次数进行排序
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.iteritems(), key = operator.itemgetter(1), reverse = True)
    return sortedClassCount[0][0]

主要函数写完以后,我们写一段测试代码,打印我们创建出的决策树:

import trees
import treePlotter
from numpy import *

lensesData, labels, vector = trees.createLensesDataSet()
parentAgeList, spectacleList, astigmaticList, tearRateList = trees.createLensesAttributeInfo()
lensesAttributeList = [parentAgeList, spectacleList, astigmaticList, tearRateList]

for i in range(len(lensesData)):
    for j in range(len(lensesData[i])):
        index = int(lensesData[i][j]) - 1
        lensesData[i][j] = lensesAttributeList[j][index]
    lensesData[i].append(str(vector[i]))

myTree = trees.createTree(lensesData, labels)
print(myTree)

我们看一下输出:

{'tear production rate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'spectacle prescription': {'hyper': {'age of the patient': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}, 'myope': 'hard'}}, 'no': {'age of the patient': {'pre': 'soft', 'presbyopic': {'spectacle prescription': {'hyper': 'soft', 'myope': 'no lenses'}}, 'young': 'soft'}}}}}}

可以看出这是一个比较长的字典嵌套结构,但是这样看上去很不直观,为了让这个决策树能直观的显示出来,我们要导入图形化模块matplotlib,用来把决策树画出来。

我们新写一个treePlotter脚本,脚本中添加计算决策树叶节点数量及深度的函数,用以计算画布的高宽布局。通过计算两个节点中点坐标的函数,确定分支属性的位置,最终画出决策树。以下是脚本代码:

import matplotlib.pyplot as plt
import matplotlib

from pylab import *
mpl.rcParams['font.sans-serif'] = ['SimHei']

# 定义文本框和箭头格式
decisionNode = dict(boxstyle = "sawtooth", fc = "0.8")
leafNode = dict(boxstyle = "round4", fc = "0.8")
arrow_args = dict(arrowstyle = "<-")

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlotPlus.ax1.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction', xytext = centerPt, textcoords = 'axes fraction', \
                            va = "center", ha = "center", bbox = nodeType, arrowprops = arrow_args)

def getNumLeafs(myTree):#获取叶节点的总数量
    numLeafs = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for k in secondDict.keys():
        if type(secondDict[k]).__name__ == 'dict':#判断节点数据类型是否为字典
            numLeafs += getNumLeafs(secondDict[k])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):#判断决策树的深度
    maxDepth = 0
    firstStr = myTree.keys()[0]
    secondDict = myTree[firstStr]
    for k in secondDict.keys():
        if type(secondDict[k]).__name__ == 'dict':  # 判断节点数据类型是否为字典
            thisDepth = 1 + getTreeDepth(secondDict[k])
        else:
            thisDepth = 1

        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def plotMidText(cntrPt, parentPt, txtString):#计算给定两个坐标的中点坐标
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlotPlus.ax1.text(xMid-0.05, yMid, txtString, rotation = 30)

def plotTree(myTree, parentPt, nodeTxt):#根据树,父节点,节点文本,绘制一个分支节点
    numLeafs = getNumLeafs(myTree)
    firstStr = myTree.keys()[0]
    cntrPt = (plotTree.xOff +(1.0 + float(numLeafs)) / 2.0 /plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for k in secondDict.keys():
        if type(secondDict[k]).__name__ =='dict':
            plotTree(secondDict[k], cntrPt, str(k))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[k], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(k))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlotPlus(inTree):#根据给定决策树创建图像
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks = [], yticks = [])
    createPlotPlus.ax1 = plt.subplot(111, frameon = False, **axprops)
    plotTree.totalW = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

经过这个脚本的处理,我们在测试代码上调用创建决策树图像的函数:

treePlotter.createPlotPlus(myTree)

得到最终图像:

以上,完成。

参考书籍:《机器学习实战》

猜你喜欢

转载自blog.csdn.net/OneWord233/article/details/83380815