决策树算法—ID3的python实现模板

首先要明确模板中 x 与 y 的格式:数据集每一行的最后一列是类别(y),其余各列是特征(x)。与其他机器学习算法不同的是,这里还需要额外提供一个 labels 数组,用来标识每个特征列的名称(如 outlook、temperature)。

# Training set: each row is [outlook, temperature, humidity, windy, class].
# Because each row mixes ints with the str class labels 'N'/'Y', np.array
# coerces the whole array to a string dtype, so the feature values are
# stored as strings such as '0', '1', '2'.
dataSet = np.array([[0, 0, 0, 0, 'N'],
                   [0, 0, 0, 1, 'N'],
                   [1, 0, 0, 0, 'Y'],
                   [2, 1, 0, 0, 'Y'],
                   [2, 2, 1, 0, 'Y'],
                   [2, 2, 1, 1, 'N'],
                   [1, 2, 1, 1, 'Y']])
# Names of the four feature columns of dataSet (the class column has no name);
# these become the internal-node names of the tree.
labels = np.array(['outlook', 'temperature', 'humidity', 'windy'])
# Test set: feature columns only (no class column); all-int, so it keeps an
# integer dtype unlike dataSet.
testSet = np.array([[0, 1, 0, 0],
           [0, 2, 1, 0],
           [2, 1, 1, 0],
           [0, 1, 1, 1],
           [1, 1, 0, 1],
           [1, 0, 1, 0],
           [2, 1, 0, 1]])

难点:选择最优的划分属性(即条件熵最小、信息增益最大的特征)

# Shannon entropy of a dataset's class column.
def dataset_entropy(dataset):
    """
    Compute the base-2 Shannon entropy of the class labels in ``dataset``.

    The class label of every sample is taken from the last column.

    :param dataset: 2-D numpy array whose last column holds the class labels
    :return: entropy -sum(p * log2(p)) over the observed classes
             (0 for an empty dataset)
    """
    classLabel = dataset[:, -1]
    total = classLabel.size
    # Frequency of each class, in first-seen order.
    counts = {}
    for label in classLabel:
        counts[label] = counts.get(label, 0) + 1
    # sum() starts from 0 and adds terms in the same order as a running total.
    return sum(-c / total * np.log2(c / total) for c in counts.values())
# Split the dataset on one value of one feature.
def splitDataSet(dataset, featureIndex, value):
    """
    Return the rows whose column ``featureIndex`` equals ``value``, with that
    column removed.

    :param dataset: 2-D array-like of samples (last column = class label)
    :param featureIndex: index of the feature column to split on
    :param value: feature value selecting the rows to keep
    :return: 2-D numpy array with one column fewer; when no row matches the
             result has shape (0, ncols - 1) instead of raising
    """
    data = np.asarray(dataset)
    # Per-element Python ``==`` (the same comparison a row loop would use), so
    # a type mismatch (e.g. '0' vs 0) simply fails to match rather than
    # depending on numpy's str-vs-int array comparison rules.
    mask = np.array([row[featureIndex] == value for row in data], dtype=bool)
    # Boolean indexing keeps the result 2-D even when nothing matches, so
    # np.delete(..., axis=1) never receives a 1-D empty array.
    return np.delete(data[mask], featureIndex, axis=1)
# Choose the best splitting feature (ID3).
def chooseBestFeature(dataset, labels):
    """
    Return the index of the feature with the smallest conditional entropy.

    Minimising the conditional entropy H(D|A) is equivalent to maximising the
    ID3 information gain H(D) - H(D|A), because H(D) is the same for every
    candidate feature at a given node.

    :param dataset: 2-D numpy array; the last column is the class label and
                    the first ``labels.size`` columns are the features
    :param labels: 1-D numpy array of feature names (only its size is used)
    :return: index of the chosen feature, or None when there are no features
    """
    n = dataset.shape[0]
    # Start from +inf: conditional entropy can reach or exceed 1 (e.g. with
    # three or more classes, or an uninformative feature on two classes), and
    # a start value of 1 would then leave bestFeatureIndex as None.
    minEntropy, bestFeatureIndex = float('inf'), None
    for i in range(labels.size):
        featureEntropy = 0
        # np.unique gives a sorted, reproducible order; iterating a set of
        # strings depends on hash randomization and can change the float sum.
        for value in np.unique(dataset[:, i]):
            subDataSet = splitDataSet(dataset, i, value)
            # Weight each subset's entropy by its share of the samples.
            featureEntropy += subDataSet.shape[0] / n * dataset_entropy(subDataSet)
        # Strict '<': on ties the earliest feature wins.
        if featureEntropy < minEntropy:
            minEntropy, bestFeatureIndex = featureEntropy, i
    return bestFeatureIndex

创建决策树时,树可以直接用字典来表示,超级方便,远远优于 Java、C++ 这些语言(也可能是我学艺不精,其他语言同样也有类似的快捷方式)

# Majority class, used when no features are left to split on.
def mayorClass(classList):
    """
    Return the most frequent label in ``classList``.

    Ties go to the label that appears first in ``classList`` (the sort is
    stable, including with ``reverse=True``). An empty input raises IndexError.

    :param classList: 1-D numpy array of class labels
    :return: the majority class label
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally, key=tally.get, reverse=True)
    return ranked[0]
# Build the decision tree recursively (ID3).
def createTree(dataset, labels):
    """
    Recursively build an ID3 decision tree.

    :param dataset: 2-D numpy array whose last column is the class label and
                    whose remaining columns correspond, in order, to ``labels``
    :param labels: 1-D numpy array of feature names
    :return: a class label (leaf) or a nested dict
             ``{featureName: {featureValue: subtree_or_leaf}}``
    """
    classList = dataset[:, -1]
    # Every sample has the same class -> leaf.
    if len(set(classList)) == 1:
        return classList[0]
    # No features left to split on -> majority-class leaf.
    if labels.size == 0:
        return mayorClass(classList)

    bestFeatureIndex = chooseBestFeature(dataset, labels)
    # Defensive: never index labels with None (numpy treats it as newaxis and
    # would silently return an array instead of a feature name).
    if bestFeatureIndex is None:
        return mayorClass(classList)
    bestFeature = labels[bestFeatureIndex]
    # The remaining feature names are the same for every branch; compute once.
    sublabels = np.delete(labels, bestFeatureIndex)
    dtree = {bestFeature: {}}
    # np.unique yields sorted values, so branch order is reproducible rather
    # than depending on string-hash-randomized set iteration.
    for value in np.unique(dataset[:, bestFeatureIndex]):
        subdataset = splitDataSet(dataset, bestFeatureIndex, value)
        dtree[bestFeature][value] = createTree(subdataset, sublabels)
    return dtree

预测数据

# Predict the class of a single sample.
def _branch_matches(value, key):
    """Return True when the test feature ``value`` selects the branch ``key``."""
    # Direct equality first: handles string test data and non-numeric values.
    if value == key:
        return True
    # Tree keys built from a mixed int/str numpy array are strings like '0';
    # compare numerically so integer test data still matches them.
    try:
        return value == int(key)
    except (TypeError, ValueError):
        return False


def predict(tree, labels, testData):
    """
    Walk ``tree`` for one sample and return the predicted class label.

    :param tree: nested ``{featureName: {featureValue: subtree_or_leaf}}``
                 dict, or a bare class label when the tree is a single leaf
    :param labels: sequence of feature names, in the column order of
                   ``testData``
    :param testData: one sample's feature values
    :return: the predicted class label, or None when the sample has a feature
             value that no branch covers
    """
    # A tree that is just a leaf is its own prediction.
    if not isinstance(tree, dict):
        return tree
    rootName = next(iter(tree))
    branches = tree[rootName]
    featureIndex = list(labels).index(rootName)
    testValue = testData[featureIndex]
    for key, subtree in branches.items():
        if _branch_matches(testValue, key):
            return predict(subtree, labels, testData)
    return None
# Predict the class of every sample in a test set.
def predictAll(tree, labels, testSet):
    """
    Return the list of predictions ``predict(tree, labels, row)`` for each
    row of ``testSet``, in order.
    """
    return [predict(tree, labels, row) for row in testSet]

调用函数

# Build the tree from the training set defined above (the variable is
# `dataSet`, not `dataset`), then classify the already-defined `testSet`
# (there is no createTestSet() function).
tree = createTree(dataSet, labels)
print(predictAll(tree, labels, testSet))
print(tree)

下面介绍一个画决策树的很好用的模板,这部分与决策树算法无关,但是可以将建好的树可视化
treePlotter.py

import matplotlib.pyplot as plt

# bbox style for internal (decision) nodes.
decisionNode = {"boxstyle": "sawtooth", "fc": "0.8"}
# bbox style for leaf nodes.
leafNode = {"boxstyle": "round4", "fc": "0.8"}
# Arrow style used by annotate() for the parent-child edges.
arrow_args = {"arrowstyle": "<-"}

def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    """Draw a boxed node labelled ``nodeTxt`` at ``centerPt``, with an arrow
    to ``parentPt``, on the axes stored in ``createPlot.ax1``."""
    axes = createPlot.ax1
    axes.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va="center",
        ha="center",
        bbox=nodeType,
        arrowprops=arrow_args,
    )

def getNumLeafs(myTree):
    """Count the leaves (non-dict values) of a nested ``{name: {value: ...}}`` tree."""
    root = list(myTree)[0]
    branches = myTree[root]
    return sum(
        getNumLeafs(child) if type(child).__name__ == 'dict' else 1
        for child in branches.values()
    )

def getTreeDepth(myTree):
    """Return the number of decision levels below the root of a nested tree
    (a root whose children are all leaves has depth 1)."""
    root = list(myTree)[0]
    childDepths = [
        getTreeDepth(child) + 1 if type(child).__name__ == 'dict' else 1
        for child in myTree[root].values()
    ]
    return max(childDepths, default=0)

def plotMidText(cntrPt, parentPt, txtString):
    """Write ``txtString`` halfway between ``cntrPt`` and ``parentPt``."""
    px, py = parentPt[0], parentPt[1]
    cx, cy = cntrPt[0], cntrPt[1]
    createPlot.ax1.text((px - cx) / 2.0 + cx, (py - cy) / 2.0 + cy, txtString)

def plotTree(myTree, parentPt, nodeTxt):
    """Recursively draw ``myTree`` below the point ``parentPt``.

    Layout state is kept in function attributes: ``plotTree.xOff`` and
    ``plotTree.yOff`` are the running drawing cursor, while
    ``plotTree.totalw`` / ``plotTree.totalD`` hold the leaf count and depth of
    the whole tree. createPlot sets all four before the first call.

    :param myTree: nested ``{featureName: {value: subtree_or_leaf}}`` dict
    :param parentPt: (x, y) position of the parent node
    :param nodeTxt: text written on the edge coming from the parent
    """
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)  # NOTE(review): computed but never used
    firstStr = list(myTree.keys())[0]
    # Centre this node horizontally over the numLeafs leaf slots its subtree occupies.
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)  # label the incoming edge
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    # Move the cursor down one level for this node's children.
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            # Subtree: recurse, using the branch value as the edge label.
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance one leaf slot to the right and draw it there.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    # Restore the vertical cursor before returning to the parent's level.
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):
    """Draw ``inTree`` on matplotlib figure 1 and show it.

    Stores the axes on ``createPlot.ax1`` and initialises the layout
    attributes of ``plotTree`` before starting the recursive drawing.
    """
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    createPlot.ax1 = plt.subplot(111, frameon=False, xticks=[], yticks=[])
    leaves = float(getNumLeafs(inTree))
    levels = float(getTreeDepth(inTree))
    plotTree.totalw, plotTree.totalD = leaves, levels
    # Start half a leaf slot left of 0, so the first leaf lands at 0.5 / totalw.
    plotTree.xOff = -0.5 / leaves
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.show()

调用

# Visualise the tree built above using the helpers in treePlotter.py.
import treePlotter
treePlotter.createPlot(tree)

效果图:

效果图

发布了32 篇原创文章 · 获赞 76 · 访问量 4078

猜你喜欢

转载自blog.csdn.net/weixin_43981664/article/details/104314208