【Machine Learning】Decision Tree (ID3 Algorithm): a Python 3 Implementation

import numpy as np               # log2 in the entropy calculation
import matplotlib.pyplot as plt  # used by the plotting functions below

def calcShannonEnt(dataSet):
    '''
    Compute the Shannon entropy of a dataset.
    :param dataSet: dataset including labels, shape = m*(n+1), where m is the number of samples and n the number of features
    :return: the Shannon entropy of the dataset
    '''
    m = len(dataSet)
    d = {}  # count of each label
    for example in dataSet:
        target = example[-1]
        if target not in d:
            d[target] = 0
        d[target] += 1
    ShannonEnt = 0
    for label in d:
        prob = d[label] / m
        ShannonEnt += -prob * np.log2(prob)
    return ShannonEnt
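
As a quick sanity check, here is a tiny hand-made dataset (a hypothetical five-sample set with two binary features, not part of the original post) and the entropy it should produce:

# Hypothetical toy dataset: two binary feature columns plus a label column.
toyData = [[1, 1, 'yes'],
           [1, 1, 'yes'],
           [1, 0, 'no'],
           [0, 1, 'no'],
           [0, 1, 'no']]
print(calcShannonEnt(toyData))  # -0.4*log2(0.4) - 0.6*log2(0.6) ≈ 0.971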


def splitDataSet(dataSet, feat, value):
    '''
    Split the dataset on a given value of a given feature.
    :param dataSet: dataset including labels, shape = m*(n+1), where m is the number of samples and n the number of features
    :param feat: index of the feature to split on
    :param value: the feature value to select
    :return: the subset of samples whose feature equals value, with that feature column removed
    '''
    splitdata = []
    for i in range(len(dataSet)):
        if dataSet[i][feat] == value:
            reducedfeat = dataSet[i][: feat]
            reducedfeat.extend(dataSet[i][feat + 1:])
            splitdata.append(reducedfeat)
    return splitdata
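
For instance, splitting the toy dataset above on feature 0 keeps only the matching rows and drops that feature column (toyData is the hypothetical set defined earlier):

print(splitDataSet(toyData, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]
print(splitDataSet(toyData, 0, 0))  # [[1, 'no'], [1, 'no']]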


def chooseBestFeatureToSplit(dataSet):
    '''
    Try a split on every value of every feature and pick the feature with the largest information gain.
    :param dataSet: dataset including labels, shape = m*(n+1), where m is the number of samples and n the number of features
    :return: index i of the best feature
    '''
    n = len(dataSet[0]) - 1
    baseEnt = calcShannonEnt(dataSet)
    bestFeature = -1
    bestInfoGain = 0
    for i in range(n):
        values = [sample[i] for sample in dataSet]
        values = set(values)
        newEnt = 0
        for value in values:
            splitdata = splitDataSet(dataSet, i, value)
            prob = len(splitdata) / len(dataSet)
            newEnt += prob * calcShannonEnt(splitdata)
        infoGain = baseEnt - newEnt
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
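
The quantity being maximized is the ID3 information gain, Gain(D, A) = H(D) - Σ_v (|D_v|/|D|) * H(D_v). Worked through by hand on the hypothetical toy dataset: feature 0 gives a gain of about 0.971 - 0.551 = 0.420, feature 1 only about 0.171, so feature 0 wins:

print(chooseBestFeatureToSplit(toyData))  # 0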


def classtarget(targetlist):
    '''
    Count the occurrences of each label in the label list and return the most frequent one.
    :param targetlist: the label list
    :return: the most frequent label
    '''
    d = {}  # count of each label
    for target in targetlist:
        if target not in d:
            d[target] = 0
        d[target] += 1
    return max(d, key=d.get)
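
A one-line check (illustrative only):

print(classtarget(['yes', 'no', 'yes']))  # 'yes'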


def creatTree(dataSet, feature_label):
    '''
    Build the decision tree recursively, choosing the best split at each step.
    :param dataSet: dataset including labels, shape = m*(n+1), where m is the number of samples and n the number of features
    :param feature_label: the code works with feature indices; this table maps index -> feature name so the tree stays human-readable
    :return: the decision tree
    '''
    targetlist = [example[-1] for example in dataSet]  # collect every sample's label into a list
    if len(set(targetlist)) == 1:  # only one class left: return it
        return targetlist[0]
    if len(dataSet[0]) == 1:  # no features left to split on (only the label column remains)
        return classtarget(targetlist)  # return the most frequent class

    bestfeature = chooseBestFeatureToSplit(dataSet)  # index of the best feature to split on
    bestfeature_label = feature_label[bestfeature]  # its human-readable name
    featlabel_copy = feature_label.copy()
    del featlabel_copy[bestfeature]  # the copy is passed to the subtrees; delete this entry so indices still line up with the reduced dataset
    mytree = {bestfeature_label: {}}  # create the root node

    values = [example[bestfeature] for example in dataSet]
    values = set(values)
    for value in values:  # build one subtree per value of the best feature
        sublabel = featlabel_copy[:]  # fresh copy of the name table for the subtree
        mytree[bestfeature_label][value] = creatTree(splitDataSet(dataSet, bestfeature, value), sublabel)  # recurse
    return mytree
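
Putting the pieces together on the toy dataset. The feature names below ('no surfacing' and 'flippers') follow the classic Machine Learning in Action fish example and are an assumption for illustration, not part of the original post:

featureLabels = ['no surfacing', 'flippers']  # hypothetical index -> name table
mytree = creatTree(toyData, featureLabels)
print(mytree)
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}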


decisionNode = dict(boxstyle='sawtooth', fc='0.8')  # decision nodes drawn with a sawtooth box
leafNode = dict(boxstyle='round4', fc='0.8')  # leaf nodes drawn with a rounded box
arrow_args = dict(arrowstyle='<-')


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    '''
    Draw an annotated node, with an arrow pointing from its parent.
    :param nodeTxt: node text
    :param centerPt: arrow end point (the node itself)
    :param parentPt: arrow start point (the parent)
    :param nodeType: node style (decision or leaf)
    :return: no return value
    '''
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', xytext=centerPt, textcoords='axes fraction',
                            va='center', ha='center', bbox=nodeType, arrowprops=arrow_args)


def getNumleafs(myTree):
    '''
    Count the leaf nodes of a decision tree, recursively.
    :param myTree: the decision tree
    :return: the number of leaf nodes
    '''
    numleafs = 0
    rootnode = list(myTree.keys())[0]  # the root node's label
    secondnode = myTree[rootnode]  # the children one level below the root
    for node in secondnode.keys():
        if type(secondnode[node]) == dict:  # this child is an internal node: recurse into it
            numleafs += getNumleafs(secondnode[node])
        else:
            numleafs += 1  # this child is a leaf: count it directly
    return numleafs


def getTreeDepth(myTree):
    '''
    Compute the depth of a decision tree, recursively.
    :param myTree: the decision tree
    :return: the depth of the decision tree
    '''
    maxdepth = 0
    rootnode = list(myTree.keys())[0]
    secondnode = myTree[rootnode]
    for node in secondnode.keys():
        if type(secondnode[node]) == dict:
            thisDepth = 1 + getTreeDepth(secondnode[node])
        else:
            thisDepth = 1
        if thisDepth > maxdepth:
            maxdepth = thisDepth
    return maxdepth
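
Both helpers can be sanity-checked on the tree built in the earlier illustrative example:

print(getNumleafs(mytree))   # 3: the leaves 'no', 'no' and 'yes'
print(getTreeDepth(mytree))  # 2: 'no surfacing' -> 'flippers' -> leaf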


def plotMidText(cntrPt, parentPt, txtString):
    '''
    Place text midway between a parent node and a child node.
    :param cntrPt: child node coordinates
    :param parentPt: parent node coordinates
    :param txtString: the text to draw
    :return: no return value
    '''
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]  # midpoint x coordinate
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]  # midpoint y coordinate
    createPlot.ax1.text(xMid, yMid, txtString)  # draw the text on ax1


def plotTree(myTree, parentPt, nodeTxt):
    '''
    Lay out and draw the tree recursively: compute its width and depth, position each node, then plot it.
    :param myTree: the decision tree
    :param parentPt: parent node coordinates
    :param nodeTxt: edge text between this node and its parent
    :return: no return value
    '''
    numLeafs = getNumleafs(myTree)  # width of this subtree
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)  # draw this decision node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # move down one level
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(myTree):
    '''
    Draw the decision tree.
    :param myTree: the decision tree
    :return: no return value
    '''
    fig = plt.figure(1, facecolor='white')  # create the figure
    fig.clf()  # clear the figure
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumleafs(myTree))  # total width = number of leaves
    plotTree.totalD = float(getTreeDepth(myTree))  # total depth of the tree
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0
    plotTree(myTree, (0.5, 1.0), '')
    plt.show()
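
With the layout helpers in place, drawing the tree from the earlier toy example is a single call (mytree comes from the illustrative example above; requires a matplotlib GUI backend):

createPlot(mytree)  # opens a window showing the annotated tree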


def classify(myTree, featLabels, testVec):
    '''
    Classify a sample with the decision tree.
    :param myTree: the trained decision tree
    :param featLabels: the full list of feature names, in their original order
    :param testVec: the test sample's feature vector
    :return: the predicted class
    '''
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    featIndex = featLabels.index(firstStr)
    key = testVec[featIndex]
    valueOfFeat = secondDict[key]
    if isinstance(valueOfFeat, dict):
        classLabel = classify(valueOfFeat, featLabels, testVec)  # internal node: recurse one level down
    else:
        classLabel = valueOfFeat  # leaf node: return its value
    return classLabel
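
An end-to-end check on the toy example. Note that classify must receive the original, full feature-name list: creatTree only deletes entries from its own copy, so featureLabels is still intact here.

print(classify(mytree, featureLabels, [1, 0]))  # 'no'
print(classify(mytree, featureLabels, [1, 1]))  # 'yes'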

Reposted from blog.csdn.net/zhenghaitian/article/details/81054767