# Python 3 decision tree implementation (from "Machine Learning in Action")

from math import log
import operator

#创建数据集
def createDataSet():
    """Return a small toy dataset and its feature names.

    Each record is [feature A, feature B, class label].
    """
    samples = [
        [1, 1, 'Yes'],
        [1, 1, 'Yes'],
        [1, 1, 'No'],
        [1, 0, 'Yes'],
        [0, 1, 'Yes'],
        [0, 1, 'Yes'],
        [0, 0, 'No'],
    ]
    featureNames = ['A', 'B']
    return samples, featureNames

#计算给定数据集的香农熵
def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the class labels.

    The class label is assumed to be the last element of each record.
    Returns 0.0 for a single-class dataset.
    """
    total = len(dataSet)
    counts = {}
    for record in dataSet:
        label = record[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0.0
    for freq in counts.values():
        p = freq / float(total)
        entropy -= p * log(p, 2)
    return entropy

#按照给定特征划分数据集
def splitDataSet(dataSet, axis, value):
    """Filter the dataset on one feature value and drop that feature.

    Returns the records whose feature at index `axis` equals `value`,
    each with the `axis` column removed (input records are not mutated).
    """
    return [record[:axis] + record[axis + 1:]
            for record in dataSet
            if record[axis] == value]

#选择最好的数据集划分方式
def chooseBestFeatureToSplit(dataSet):
    """Return the index of the feature with the highest information gain.

    Returns -1 when no feature yields a strictly positive gain.
    """
    featureCount = len(dataSet[0]) - 1      # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)   # entropy before any split
    bestGain, bestIdx = 0.0, -1
    for idx in range(featureCount):
        # Conditional entropy after splitting on feature `idx`.
        condEntropy = 0.0
        for v in {record[idx] for record in dataSet}:
            subset = splitDataSet(dataSet, idx, v)
            weight = len(subset) / float(len(dataSet))
            condEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - condEntropy
        if gain > bestGain:
            bestGain, bestIdx = gain, idx
    return bestIdx

#数据集处理完所有属性,类标签依然不是唯一的,采用多数表决的方法决定该叶子节点的分类
def majorityCnt(classList):
    """Return the most frequent label in classList (majority vote).

    Ties are broken by first appearance, since the sort is stable.
    """
    tally = {}
    for label in classList:
        tally[label] = tally.get(label, 0) + 1
    ranked = sorted(tally.items(), key=lambda kv: kv[1], reverse=True)
    return ranked[0][0]

#创建树
#创建树
def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree as nested dicts.

    dataSet: list of records, each [feat0, feat1, ..., classLabel].
    labels:  feature names aligned with the feature columns of dataSet.
    Returns either a class label (leaf) or {featureName: {value: subtree}}.
    """
    classList = [x[-1] for x in dataSet]   # all class labels in this node
    if classList.count(classList[0]) == len(classList):   # pure node: single class
        return classList[0]
    if len(dataSet[0]) == 1:   # no features left; cannot split further
        return majorityCnt(classList)   # majority vote decides the leaf
    bestFeat = chooseBestFeatureToSplit(dataSet)   # index of the best feature
    bestFeatureLabel = labels[bestFeat]    # its human-readable name
    myTree = {bestFeatureLabel: {}}
    # BUG FIX: splitDataSet removes column `bestFeat` from the records, so the
    # label list passed down must drop the same entry; passing a full copy
    # (labels[:]) made deeper levels name the wrong feature.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = [x[bestFeat] for x in dataSet]
    for value in set(featValues):
        myTree[bestFeatureLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels)
    return myTree

def main():
    """Build a decision tree from the sample dataset and print it."""
    data, featureNames = createDataSet()
    tree = createTree(data, featureNames)
    print(tree)

# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

# Source: blog.csdn.net/qq_42591058/article/details/88943874