机器学习_决策树

#决策树

import matplotlib.pyplot as plt
from math import log
import operator
from matplotlib import font_manager
font = font_manager.FontProperties(fname=r"c:\windows\fonts\SimHei.ttf")
def createDataSet():
    # dataSet = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
    #            [1, 0, 1, 0, 0, 0, 'yes'],
    #            [1, 0, 0, 0, 0, 0, 'yes'],
    #            [0, 0, 1, 0, 0, 0, 'yes'],
    #            [2, 0, 0, 0, 0, 0, 'yes'],
    #            [0, 1, 0, 0, 1, 1, 'yes'],
    #            [1, 1, 0, 1, 1, 1, 'yes'],
    #            [1, 1, 0, 0, 1, 0, 'yes'],
    #            [1, 1, 1, 1, 1, 0, 'no'],
    #            [0, 2, 2, 0, 2, 1, 'no'],
    #            [2, 2, 2, 2, 2, 0, 'no'],
    #            [2, 0, 0, 2, 2, 1, 'no'],
    #            [0, 1, 0, 1, 0, 0, 'no'],
    #            [2, 1, 1, 1, 0, 0, 'no'],
    #            [1, 1, 0, 0, 1, 1, 'no'],
    #            [2, 0, 0, 2, 2, 0, 'no'],
    #            [0, 0, 1, 1, 1, 0, 'no']]
    dataSet = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
               ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 'yes'],
               ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
               ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 'yes'],
               ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
               ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 'yes'],
               ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', 'yes'],
               ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', 'yes'],
               ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', 'no'],
               ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', 'no'],
               ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 'no'],
               ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', 'no'],
               ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', 'no'],
               ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', 'no'],
               ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 'no'],
               ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', 'no'],
               ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', 'no']]
    labels = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
    return dataSet, labels
def createTree(dataset,labels,featLabels):
    """创建根结点并实现其基本流程"""
    classList = [example[-1] for example in dataset] 
    # 设置递归停止条件
    # 如果数据集很纯净,就返回当前类别yes和no,即一种特征就可以判断出来是好瓜好事坏瓜
    if (classList.count(classList[0])) == len(classList):
        return classList[0] 
    # 如果只剩下一列特征不能继续再划分时,停止递归
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    # 计算信息熵,取出最优属性,得出其索引值
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat] # 最优根结点
    featLabels.append(bestFeatLabel) # 将其添加到featLabels参数中,储存每个分支的特征标签,即对于每个节点featLabels中的一个元素对应于该节点所选择的特征
    myTree = {
    
    bestFeatLabel:{
    
    }}
    # 删除该属性
    del labels[bestFeat]
    featValue = [example[bestFeat] for example in dataset]
    # 取出唯一值
    uniqueVals = set(featValue)
    # 利用其特征进行分叉
    for value in uniqueVals:
        sub_lables = labels[:]
        # 递归分叉
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sub_lables,featLabels)
    return myTree
def createTree(dataset,labels,featLabels):
    """创建根结点并实现其基本流程"""
    classList = [example[-1] for example in dataset] 
    # 设置递归停止条件
    # 如果数据集很纯净,就返回当前类别yes和no,即一种特征就可以判断出来是好瓜好事坏瓜
    if (classList.count(classList[0])) == len(classList):
        return classList[0] 
    # 如果只剩下一列特征不能继续再划分时,停止递归
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    # 计算信息熵,取出最优属性,得出其索引值
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat] # 最优根结点
    featLabels.append(bestFeatLabel) # 将其添加到featLabels参数中,储存每个分支的特征标签,即对于每个节点featLabels中的一个元素对应于该节点所选择的特征
    myTree = {
    
    bestFeatLabel:{
    
    }}
    # 删除该属性
    del labels[bestFeat]
    featValue = [example[bestFeat] for example in dataset]
    # 取出唯一值
    uniqueVals = set(featValue)
    # 利用其特征进行分叉
    for value in uniqueVals:
        sub_lables = labels[:]
        # 递归分叉
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sub_lables,featLabels)
    return myTree
def splitDataSet(dataset, axis, val):
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == val:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet
# 用于切分数据集。函数的输入参数包括:数据集(dataset)、切分数据集的特征(axis)和需要返回的特征值(val)。
# 函数的输出是一个列表,其中包含了数据集中所有特征值为val的数据行,并且这些行已经去掉了特征值为axis的那一列。
# 函数通过遍历数据集中的每一行,判断该行的axis列是否等于val。
# 如果是,就把该行的axis列去掉,并将其它列组成一个新的列表redceFeatVec,然后将这个列表添加到retDataSet中。
# 最后,函数返回retDataSet。
def calcShannonEnt(dataset):
    """计算熵值"""
    numexamples = len(dataset) # 总体数据
    labelCounts = {
    
    }
    # 取出yes和no便于后面计算概率
    for featVec in dataset:
        currentlabel = featVec[-1]
        if currentlabel not in labelCounts.keys():labelCounts[currentlabel] = 0
        labelCounts[currentlabel] += 1
    shannonEnt = 0
    # 计算熵值
    for key in labelCounts:
        prop = float(labelCounts[key])/numexamples
        shannonEnt -= prop * log(prop,2)
    return shannonEnt
def majorityCnt(classList):
    """计算多数类别是哪一个"""
    classCount = {
    
    }
    # 这个for循环是将classList中的yes和no统计出来保存在字典中
    for vote in classList:
        # 如果不在这个字典里面就将他的key设置为零
        if vote not in classCount.keys():classCount[vote] = 0
        # 在的话就+=1
        classCount[vote] +=1
    sortedclassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse=True)
    # print(sortedclassCount)
    return sortedclassCount[0][0]

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth
# 这段代码是一个递归函数,计算输入的决策树中叶节点的数量。
# 1、函数获取决策树的根节点,并进入其子树。
# 2、遍历该子树的每个分支,如果当前节点是一个字典类型,则递归调用该函数,继续遍历其子节点。
# 如果当前节点是叶节点,即不再包含子节点,则将叶节点的数量加1。
# 3、函数返回整个决策树中叶节点的总数。
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    arrow_args = dict(arrowstyle="<-")
    # font = FontProperties(fname=r"C:\Windows\Fonts\Corbel.ttf", size=14)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args,fontproperties=font)

# 是用来计算一个嵌套字典的深度(也可以理解为树的深度)。
# 其中,输入参数`myTree`是一个嵌套字典,表示一个树形结构,每个节点都是一个字典。函数返回值是这个树的深度。
# 1. 初始化变量`maxDepth`为0,表示当前树的深度为0。
# 2. 从字典`myTree`中获取第一个键值对,即根节点。将根节点的值(一个字典)赋值给变量`secondDict`。
# 3. 遍历`secondDict`中的每个键,判断对应的值是否为字典。如果是字典,则递归调用`getTreeDepth`函数,计算以该节点为根的子树的深度。
# 否则,该节点为叶子节点,深度为1。
# 4. 将当前节点的深度`thisDepth`与当前最大深度`maxDepth`比较,更新`maxDepth`为较大值。
# 5. 返回最大深度`maxDepth`。

# def plotMidText(cntrPt, parentPt, txtString):
#     xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
#     yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
#
#     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotMidText(centerPt, parentPt, txtString):
    xMid = (parentPt[0] - centerPt[0]) / 2.0 + centerPt[0]
    yMid = (parentPt[1] - centerPt[1]) / 2.0 + centerPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30,fontproperties=font)
# 该函数用于在父节点和子节点之间画出标签,显示它们之间的关系。具体解释如下:
# - centerPt:当前节点的坐标
# - parentPt:父节点的坐标
# - txtString:标签内容
# - xMid:标签的x坐标,计算方式为父节点和子节点的x坐标的平均值
# - yMid:标签的y坐标,计算方式为父节点和子节点的y坐标的平均值
# - va:标签的垂直对齐方式,"center"表示居中对齐
# - ha:标签的水平对齐方式,"center"表示居中对齐
# - rotation:标签的旋转角度,30表示旋转30度
# - fontproperties:字体属性,用于设置标签的字体大小、颜色等。
    
def plotTree(myTree, parentPt, nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    
# 代码用来绘制决策树的。
# 首先,定义了两个字典decisionNode和leafNode,分别表示决策节点和叶子节点的样式。
# 然后,获取决策树的叶子节点数量和深度,以及根节点的属性名firstStr。
# 接着,计算当前节点的中心位置cntrPt,并调用plotMidText和plotNode函数绘制节点的文本和样式。
# 然后,获取该节点的子节点secondDict,遍历其所有键值对,如果值是一个字典,则说明该节点不是叶子节点,需要递归调用plotTree函数来绘制其子节点。
# 如果值不是一个字典,则说明该节点是叶子节点,需要调用plotNode和plotMidText函数绘制该节点的样式和文本。
# 最后,更新plotTree的yOff值,以便绘制下一个节点。
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # 创建fig
    fig.clf()  # 清空fig
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # 去掉x、y轴
    plotTree.totalW = float(getNumLeafs(inTree))  # 获取决策树叶结点数目
    plotTree.totalD = float(getTreeDepth(inTree))  # 获取决策树层数
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0  # x偏移
    plotTree(inTree, (0.5, 1.0), '')  # 绘制决策树
    plt.show()

# 创建并显示决策树的可视化图形的。
# 首先,创建一个白色背景的fig对象,并清空该对象。
# 然后,定义一个字典axprops,用于去掉x、y轴的刻度。
# 接着,创建一个子图ax1,并将其frameon属性设置为False,以便去掉边框。
# 接下来,获取决策树的叶子节点数量和深度,并初始化plotTree的xOff和yOff值。
# 最后,调用plotTree函数绘制决策树,并显示图形。
if __name__ == '__main__':
    dataset, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataset, labels, featLabels)
    createPlot(myTree)

# 这段代码是用来测试决策树算法的。
# 首先,调用createDataSet函数生成一个简单的数据集和标签。
# 然后,定义一个空列表featLabels,用于存储决策树的属性标签。
# 接着,调用createTree函数生成决策树,并将属性标签存储在featLabels中。
# 最后,调用createPlot函数绘制决策树的可视化图形。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-wVraJ1HS-1683703643289)(output_14_0.png)]


猜你喜欢

转载自blog.csdn.net/cfy2401926342/article/details/130602096