机器学习——决策树的剪枝

一、决策树剪枝

1.为什么要进行剪枝?

原因:上一篇博客是使用ID算法构建的决策树,而ID3算法构建的决策树可能存在以下问题:
① 不能对连续数据进行处理,只能通过连续数据离散化进行处理;
② 采用信息增益容易偏向取值较多的特征,准确率不如信息增益率;
③没有采用剪枝,决策树的结构可能过于复杂,容易出现过拟合。

2.什么是剪枝?

概念:剪枝主要是解决决策树出现的“过拟合”现象。剪枝就是通过某种判断,避免一些不必要的遍历过程。剪枝剪枝又分为预剪枝和后剪枝。

预剪枝:预剪枝,就是将即将发芽的分支“扼杀在萌芽状态”即在分支划分前就进行剪枝判断,如果判断结果是需要剪枝,则不进行该分支划分。

后剪枝(自底向上):在分支划分之后,通常是决策树的各个判断分支已经形成后,才开始进行剪枝判断。

二、代码实现

1.数据集准备:在上一篇文章的数据下改进新的数据集:

动物名称 食性 毛发 生活环境 哺乳动物
猫科 短毛 草原
爬行 森林
爬行 水里
两栖 多毛 草原
两栖 短毛 草原
猫科 多毛 森林
两栖 森林
两栖 水里
飞行 水里
爬行 水里
飞行 短毛 森林
爬行 森林
猫科 森林
爬行 水里
飞行 森林
两栖 水里
两栖 短毛 草原
猫科 多毛 草原
飞行 杂食 短毛 海边

2.没有剪枝前的决策树

#树的可视化

decisionNodeStyle = dict(boxstyle="sawtooth", fc="0.8")
leafNodeStyle = {"boxstyle": "round4", "fc": "0.8"}
arrowArgs = {"arrowstyle": "<-"}


# 画节点
def plotNode(nodeText, centerPt, parentPt, nodeStyle):
    createPlot.ax1.annotate(nodeText, xy=parentPt, xycoords="axes fraction", xytext=centerPt
                            , textcoords="axes fraction", va="center", ha="center", bbox=nodeStyle,
                            arrowprops=arrowArgs)


# 添加箭头上的标注文字
def plotMidText(centerPt, parentPt, lineText):
    xMid = (centerPt[0] + parentPt[0]) / 2.0
    yMid = (centerPt[1] + parentPt[1]) / 2.0
    createPlot.ax1.text(xMid, yMid, lineText)


# 画树
def plotTree(decisionTree, parentPt, parentValue):
    # 计算宽与高
    leafNum, treeDepth = getTreeSize(decisionTree)
    # 在 1 * 1 的范围内画图,因此分母为 1
    # 每个叶节点之间的偏移量
    plotTree.xOff = plotTree.figSize / (plotTree.totalLeaf - 1)
    # 每一层的高度偏移量
    plotTree.yOff = plotTree.figSize / plotTree.totalDepth
    # 节点名称
    nodeName = list(decisionTree.keys())[0]
    # 根节点的起止点相同,可避免画线;如果是中间节点,则从当前叶节点的位置开始,
    #      然后加上本次子树的宽度的一半,则为决策节点的横向位置
    centerPt = (plotTree.x + (leafNum - 1) * plotTree.xOff / 2.0, plotTree.y)
    # 画出该决策节点
    plotNode(nodeName, centerPt, parentPt, decisionNodeStyle)
    # 标记本节点对应父节点的属性值
    plotMidText(centerPt, parentPt, parentValue)
    # 取本节点的属性值
    treeValue = decisionTree[nodeName]
    # 下一层各节点的高度
    plotTree.y = plotTree.y - plotTree.yOff
    # 绘制下一层
    for val in treeValue.keys():
        # 如果属性值对应的是字典,说明是子树,进行递归调用; 否则则为叶子节点
        if type(treeValue[val]) == dict:
            plotTree(treeValue[val], centerPt, str(val))
        else:
            plotNode(treeValue[val], (plotTree.x, plotTree.y), centerPt, leafNodeStyle)
            plotMidText((plotTree.x, plotTree.y), centerPt, str(val))
            # 移到下一个叶子节点
            plotTree.x = plotTree.x + plotTree.xOff
    # 递归完成后返回上一层
    plotTree.y = plotTree.y + plotTree.yOff


# 画出决策树
def createPlot(decisionTree):
    fig = plt.figure(1, facecolor="white")
    fig.clf()
    axprops = {"xticks": [], "yticks": []}
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    # 定义画图的图形尺寸
    plotTree.figSize = 1.5
    # 初始化树的总大小
    plotTree.totalLeaf, plotTree.totalDepth = getTreeSize(decisionTree)
    # 叶子节点的初始位置x 和 根节点的初始层高度y
    plotTree.x = 0
    plotTree.y = plotTree.figSize
    plotTree(decisionTree, (plotTree.figSize / 2.0, plotTree.y), "")
    plt.show()

 3.输出结果如下:

{'动物名称': {'猫科': '是', '两栖': {'生活环境': {'森林': '是', '草原': '是', '水里': '否'}}, '爬行': {'食性': {'肉': {'生活环境': {'森林': '否', '水里': {'毛发': {'无': '否'}}}}, '杂': '否'}}, '飞行': '否'}}

 结论:可以直接判断出猫科类的动物都是哺乳动物,飞行类的动物都不是哺乳动物,两栖类的动物如果生活在森林和草原就属于哺乳动物,如果生活在水里就不是哺乳动物。

4.预剪枝后的树:

创建的数据集代码

#创建数据集
def createData():
    data = np.array([['猫科', '肉', '短毛', '草原'],
    ['爬行', '肉', '无', '森林'],
    ['爬行', '杂', '无', '水里'],
    ['两栖', '草', '多毛', '草原'],
    ['两栖', '草', '短毛', '草原'],
    ['猫科', '杂', '多毛', '森林'],
    ['两栖', '草', '无', '森林'],
    ['两栖', '杂', '无', '水里'],
    ['爬行', '肉', '无', '水里'],
    ['爬行', '肉', '无', '水里'],
    ['飞行', '杂', '短毛', '森林'],
    ['爬行', '杂', '短毛', '森林'],
    ['猫科', '肉', '无', '森林'],
    ['爬行', '肉', '无', '水里'],
    ['飞行', '草', '无', '森林'],
    ['爬行', '杂', '无', '水里'],
    ['两栖', '草', '短毛', '草原'],
    ['猫科', '肉', '多毛', '草原'],
    ['飞行', '杂', '短毛', '海边']])
    label = np.array(['是', '否', '否', '是', '是', '是', '是', '否', '否', '是', '否', '否', '是', '否', '否', '否', '是', '是', '否'])
    name = np.array(['动物名称', '食性', '毛发', '生活环境'])
    return data, label, name

划分的数据集代码:其中前10个数据为训练集,后10个代码为 测试集

def splitXgData(xgData, xgLabel):
    xgDataTrain = xgData[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16,17],:]
    xgDataTest = xgData[[3, 4, 7, 8, 10, 11, 12,18],:]
    xgLabelTrain = xgLabel[[0, 1, 2, 5, 6, 9, 13, 14, 15, 16,17]]
    xgLabelTest = xgLabel[[3, 4, 7, 8, 10, 11, 12,18]]
    return xgDataTrain, xgLabelTrain, xgDataTest, xgLabelTest
# 创建预剪枝决策树
def createTreePrePruning(dataTrain, labelTrain, dataTest, labelTest, names, method='id3'):

    trainData = np.asarray(dataTrain)
    labelTrain = np.asarray(labelTrain)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)
    # 如果结果为单一结果
    if len(set(labelTrain)) == 1:
        return labelTrain[0]
        # 如果没有待分类特征
    elif trainData.size == 0:
        return voteLabel(labelTrain)
    # 其他情况则选取特征
    bestFeat, bestEnt = bestFeature(dataTrain, labelTrain, method=method)
    # 取特征名称
    bestFeatName = names[bestFeat]
    # 从特征名称列表删除已取得特征名称
    names = np.delete(names, [bestFeat])
    # 根据最优特征进行分割
    dataTrainSet, labelTrainSet = splitFeatureData(dataTrain, labelTrain, bestFeat)

    # 预剪枝评估
    # 划分前的分类标签
    labelTrainLabelPre = voteLabel(labelTrain)
    labelTrainRatioPre = equalNums(labelTrain, labelTrainLabelPre) / labelTrain.size
    # 划分后的精度计算
    if dataTest is not None:
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, bestFeat)
        # 划分前的测试标签正确比例
        labelTestRatioPre = equalNums(labelTest, labelTrainLabelPre) / labelTest.size
        # 划分后 每个特征值的分类标签正确的数量
        labelTrainEqNumPost = 0
        for val in labelTrainSet.keys():
            labelTrainEqNumPost += equalNums(labelTestSet.get(val), voteLabel(labelTrainSet.get(val))) + 0.0
        # 划分后 正确的比例
        labelTestRatioPost = labelTrainEqNumPost / labelTest.size

        # 如果没有评估数据 但划分前的精度等于最小值0.5 则继续划分
    if dataTest is None and labelTrainRatioPre == 0.5:
        decisionTree = {bestFeatName: {}}
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue),
                                                                         labelTrainSet.get(featValue)
                                                                         , None, None, names, method)
    elif dataTest is None:
        return labelTrainLabelPre
        # 如果划分后的精度相比划分前的精度下降, 则直接作为叶子节点返回
    elif labelTestRatioPost < labelTestRatioPre:
        return labelTrainLabelPre
    else:
        # 根据选取的特征名称创建树节点
        decisionTree = {bestFeatName: {}}
        # 对最优特征的每个特征值所分的数据子集进行计算
        for featValue in dataTrainSet.keys():
            decisionTree[bestFeatName][featValue] = createTreePrePruning(dataTrainSet.get(featValue),
                                                                         labelTrainSet.get(featValue)
                                                                         , dataTestSet.get(featValue),
                                                                         labelTestSet.get(featValue)
                                                                         , names, method)
    return decisionTree


# 测试函数
xgData, xgLabel, xgName = createData()
xgTree = createTree(xgData, xgLabel, xgName, method = 'id3')
print(xgTree)
createPlot(xgTree)

 {'动物名称': {'飞行': '否', '猫科': '是', '两栖': {'生活环境': {'森林': '是', '水里': '否', '草原': '是'}}, '爬行': {'食性': {'肉': {'生活环境': {'森林': '否', '水里': {'毛发': {'无': '否'}}}}, '杂': '否'}}}}

扫描二维码关注公众号,回复: 15201743 查看本文章

5.后剪枝的树:

代码

# 后剪枝 训练完成后决策节点进行替换评估  这里可以直接对xgTreeTrain进行操作
def treePostPruning(labeledTree, dataTest, labelTest, names):
    newTree = labeledTree.copy()
    dataTest = np.asarray(dataTest)
    labelTest = np.asarray(labelTest)
    names = np.asarray(names)
    # 取决策节点的名称 即特征的名称
    featName = list(labeledTree.keys())[0]
    # print("\n当前节点:" + featName)
    # 取特征的列
    featCol = np.argwhere(names == featName)[0][0]
    names = np.delete(names, [featCol])
   
    # 该特征下所有值的字典
    newTree[featName] = labeledTree[featName].copy()
    featValueDict = newTree[featName]
    featPreLabel = featValueDict.pop("_vpdl")
    # print("当前节点预划分标签:" + featPreLabel)
    # 是否为子树的标记
    subTreeFlag = 0
    # 分割测试数据 如果有数据 则进行测试或递归调用  np的array我不知道怎么判断是否None, 用is None是错的
    dataFlag = 1 if sum(dataTest.shape) > 0 else 0
    if dataFlag == 1:
        # print("当前节点有划分数据!")
        dataTestSet, labelTestSet = splitFeatureData(dataTest, labelTest, featCol)
    for featValue in featValueDict.keys():
        # print("当前节点属性 {0} 的子节点:{1}".format(featValue ,str(featValueDict[featValue])))
        if dataFlag == 1 and type(featValueDict[featValue]) == dict:
            subTreeFlag = 1
            # 如果是子树则递归
            newTree[featName][featValue] = treePostPruning(featValueDict[featValue], dataTestSet.get(featValue),
                                                           labelTestSet.get(featValue), names)
            # 如果递归后为叶子 则后续进行评估
            if type(featValueDict[featValue]) != dict:
                subTreeFlag = 0

                # 如果没有数据  则转换子树
        if dataFlag == 0 and type(featValueDict[featValue]) == dict:
            subTreeFlag = 1
            
            newTree[featName][featValue] = convertTree(featValueDict[featValue])
          
    
    if subTreeFlag == 0:
        ratioPreDivision = equalNums(labelTest, featPreLabel) / labelTest.size
        equalNum = 0
        for val in labelTestSet.keys():
            equalNum += equalNums(labelTestSet[val], featValueDict[val])
        ratioAfterDivision = equalNum / labelTest.size
        
        if ratioAfterDivision < ratioPreDivision:
            newTree = featPreLabel
    return newTree

 结论:后剪枝图片运行结果与预剪枝结果相同,可能数据集出现问题,还在解决中

 三、个人总结

1.在使用画图工具时要记得加注释防止乱码

# 设置中文显示字体

from pylab import mpl

# 设置中文显示字体

mpl.rcParams["font.sans-serif"] = ["SimHei"]

 2.预剪枝和后剪枝的优缺点

预剪枝:

优点:采用贪心算法的策略,适合大规模问题。

缺点:会提前停止生长,还是可能存在欠拟合的风险

 后剪枝:

优点:可以最大限度的保留树的各个节点,避免了欠拟合的风险。

缺点:需要的运算时间较长

猜你喜欢

转载自blog.csdn.net/Gucciwei/article/details/127954359
今日推荐