[机器学习实战]决策树

原理

决策树图示
通过提问的方式,根据不同的答案选择不同的分支, 完成不同的分类

步骤分解

1.遍历数据集, 循环计算提取每个特征的香农熵和信息增益, 选取信息增益最大的特征。 再递归计算剩余的特征顺序。 将特征排序。 并将分类结果序列化保存到磁盘当中

def chooseBestFeatureToSplit(dataSet):  # 选择最好的分类特征
    """
    :param dataSet: 原数据集
    :return: 最好的划分特征的索引值
    """
    numFeatures = len(dataSet[0]) - 1   # 获取特征数
    baseEntropy = calcShannonEnt(dataSet)   # 计算数据集的信息熵
    bestInfoGain = 0.0      # 初始化最好的信息熵
    bestFeature = -1        # 初始化最好的用于分割的特征
    for i in range(numFeatures):
        # 创建唯一的分类标签列表
        featList= [example[i] for example in dataSet]   # 获取每个元素的第i个特征
        uniqueVals = set(featList)  # 数据特征去重 (此特征有几种情况)
        newEntropy = 0.0

        # 计算每种划分方式的信息熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))    # probability,概率,可理解为权重
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     # 新的熵越小即新划分的数据集混乱程度越小,与原熵的差值就越大, 即信息增益就越大

        # 计算最好的信息增益
        if(infoGain > bestInfoGain):    # 若新的信息增益大于之前的信息增益,则替换
            bestInfoGain = infoGain
            bestFeature = i     # 表示最好的划分特征的索引值
    return bestFeature

2.递归构建决策树

def createTree(dataSet, labels):
    """
    :param dataSet: 数据集
    :param labels: 标签列表, 包含了数据集中的所有特征的标签
    :return:
    """
    classList = [example[-1] for example in dataSet]

    # 类别完全相同则停止继续划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # 遍历完所有特征,仍然不能将数据集划分成包含唯一类别的分组时,返回出现次数最多的类别
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)    # 选取的最好特征
    bestFeatLabel = labels[bestFeat]    # 最好特征标签
    myTree = {bestFeatLabel:{ }}    # 使用字典存储树信息

    # 得到列表包含的所有属性值
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    for value in uniqueVals:
        subLabels = labels[:]   # 因为下一步传参数时是引用传参
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)

    return myTree

3.使用Matplotlib注解绘制树形图

import matplotlib.pyplot as plt
import trees

# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")    # 设置箭头的样式和背景色
arrow_args = dict(arrowstyle="<-")  # 设置箭头的样式

def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # 设置背景色
    fig.clf()   # 清空画布
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops) #表示图中有1行1列,绘图放在第几列, 有无边框
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)   # 第一个坐标是注解的坐标 第二个坐标是点的坐标
    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

def plotMidText(cntrPt, parentPt, txtString): # 在父子节点间填充文本信息
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):    #计算宽与高
    numLeafs = trees.getNumLeafs(myTree)
    depth = trees.getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # 根节点
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # 标记子节点属性值
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # 减少y偏移
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


def retrieveTree(i): # 创建树
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

完整代码

trees.py

from math import log
import operator
import treePlotter

def calcShannonEnt(dataSet):    # 计算给定数据集的香农熵
    numEntries = len(dataSet)
    labelCounts = {}

    # 为所有可能的分类创建字典
    for featVec in dataSet:
        currentLabel = featVec[-1]
        # if currentLabel not in labelCounts.keys():
        #     labelCounts[currentLabel] = 0
        # labelCounts[currentLabel] += 1
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

    shannonEnt = 0.0
    for key in labelCounts:
        # 以2为底求对数
        prob = float(labelCounts[key]) / numEntries
        shannonEnt -= prob * log(prob, 2)

    return shannonEnt

def splitDataSet(dataSet, axis, value): # 按照给定特征划分数据集
    """
    :param dataSet: 待划分的数据集
    :param axis: 划分数据集的特征
    :param value: 特征的返回值
    :return:
    """
    # 创建新的list对象
    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:  # 抽取
            reducedFratVec = featVec[:axis]
            reducedFratVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFratVec)
    return retDataSet

def chooseBestFeatureToSplit(dataSet):  # 选择最好的分类特征
    """
    :param dataSet: 原数据集
    :return: 最好的划分特征的索引值
    """
    numFeatures = len(dataSet[0]) - 1   # 获取特征数
    baseEntropy = calcShannonEnt(dataSet)   # 计算数据集的信息熵
    bestInfoGain = 0.0      # 初始化最好的信息熵
    bestFeature = -1        # 初始化最好的用于分割的特征
    for i in range(numFeatures):
        # 创建唯一的分类标签列表
        featList= [example[i] for example in dataSet]   # 获取每个元素的第i个特征
        uniqueVals = set(featList)  # 数据特征去重 (此特征有几种情况)
        newEntropy = 0.0

        # 计算每种划分方式的信息熵
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))    # probability,概率,可理解为权重
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy     # 新的熵越小即新划分的数据集混乱程度越小,与原熵的差值就越大, 即信息增益就越大

        # 计算最好的信息增益
        if(infoGain > bestInfoGain):    # 若新的信息增益大于之前的信息增益,则替换
            bestInfoGain = infoGain
            bestFeature = i     # 表示最好的划分特征的索引值
    return bestFeature

def majorityCnt(classList): # 多数表决决定叶子节点的分类
    """
    :param classList: 类别列表
    :return: 出现次数最多的分类名称
    """
    classCount = {}
    for vote in classList:  # 统计分类列表中个类别出现的次数
        # if vote not in classCount.keys(): classCount[vote] = 0
        # classCount[vote] += 1
        classCount[vote] = classCount.get(vote, 0) + 1
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)  # 根据出现次数排序
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    """
    :param dataSet: 数据集
    :param labels: 标签列表, 包含了数据集中的所有特征的标签
    :return:
    """
    classList = [example[-1] for example in dataSet]

    # 类别完全相同则停止继续划分
    if classList.count(classList[0]) == len(classList):
        return classList[0]

    # 遍历完所有特征,仍然不能将数据集划分成包含唯一类别的分组时,返回出现次数最多的类别
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)

    bestFeat = chooseBestFeatureToSplit(dataSet)    # 选取的最好特征
    bestFeatLabel = labels[bestFeat]    # 最好特征标签
    myTree = {bestFeatLabel:{ }}    # 使用字典存储树信息

    # 得到列表包含的所有属性值
    del(labels[bestFeat])
    featValues = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)

    for value in uniqueVals:
        subLabels = labels[:]   # 因为下一步传参数时是引用传参
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)

    return myTree

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        # 测试节点的数据类型是否为字典
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs

def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth

def createDataSet():    # 创建数据集
    dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'], [0, 1, 'no'], [0, 1, 'no']]
    labels = ['no surfacing', 'flippers']
    return dataSet, labels

def classify(inputTree, featLabels, testVec):   # 分类器
    """
    :param inputTree: 树,即数据集
    :param featLabels: 特征标签
    :param testVec: 待测向量
    :return: 类别
    """
    firstStr = list(inputTree.keys())[0]
    # 将标签字符串转换为索引
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)

    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)     # 若未到叶子节点,则继续往下递归,直到叶子节点
            else:
                classLabel = secondDict[key]        # 如果已到叶子节点, 则直接取dict当前key的value
    return classLabel

def storeTree(inputTree, filename):     # 序列化保存树(分类信息)
    import pickle
    fw = open(filename, 'wb+')
    pickle.dump(inputTree, fw)
    fw.close()

def grabTree(filename):     # 读取序列化文件
    import pickle
    fr = open(filename, "rb+")
    return pickle.load(fr)

if __name__ == "__main__":
    myDat, labels = createDataSet()
    # myTree = createTree(myDat, labels)
    # print(myTree)
    print(myDat)
    myTree = treePlotter.retrieveTree(0)
    print(myTree)
    print(classify(myTree, labels, [1, 0]))
    print(classify(myTree, labels, [1, 1]))
    print("===========store tree============")
    storeTree(myTree, 'classifierStorafe.txt')
    print(grabTree('classifierStorafe.txt'))

treePlotter

import matplotlib.pyplot as plt
import trees

# 定义文本框和箭头格式
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")    # 设置箭头的样式和背景色
arrow_args = dict(arrowstyle="<-")  # 设置箭头的样式

def plotNode(nodeTxt, centerPt, parentPt, nodeType): # 绘制带箭头的注解
    createPlot.axl.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)

def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # 设置背景色
    fig.clf()   # 清空画布
    axprops = dict(xticks=[], yticks=[])
    createPlot.axl = plt.subplot(111, frameon=False, **axprops) #表示图中有1行1列,绘图放在第几列, 有无边框
    plotTree.totalW = float(trees.getNumLeafs(inTree))
    plotTree.totalD = float(trees.getTreeDepth(inTree))
    plotTree.xOff = -0.5/plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), ' ')
    # plotNode('a decision node', (0.5, 0.5), (0.1, 0.5), decisionNode)   # 第一个坐标是注解的坐标 第二个坐标是点的坐标
    # plotNode('a leaf node', (0.8, 0.1), (0.3, 0.8), leafNode)
    plt.show()

def plotMidText(cntrPt, parentPt, txtString): # 在父子节点间填充文本信息
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.axl.text(xMid, yMid, txtString)

def plotTree(myTree, parentPt, nodeTxt):    #计算宽与高
    numLeafs = trees.getNumLeafs(myTree)
    depth = trees.getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]   # 根节点
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    # 标记子节点属性值
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # 减少y偏移
    plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD


def retrieveTree(i): # 创建树
    listOfTrees =[{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
                  {'no surfacing': {0: 'no', 1: {'flippers': {0: {'head': {0: 'no', 1: 'yes'}}, 1: 'no'}}}}]
    return listOfTrees[i]

if __name__ == "__main__":
    # reTree = retrieveTree(1)
    # leafs = trees.getNumLeafs(reTree)
    # depth = trees.getTreeDepth(reTree)
    # print(reTree)
    # print(leafs)
    # print(depth)
    myTree = retrieveTree(0)
    myTree['no surfacing'][3] = 'maybe'
    createPlot(myTree)

猜你喜欢

转载自blog.csdn.net/vi_nsn/article/details/78896882