基于python实现的决策树算法-ID3

使用Python 实现决策树算法,主要是对ID3算法的实现

决策树算法的理论知识,这里不做介绍,网上很多,可以看该博客的,写的很详细,也有例子

虽然也有很多人也使用Python 写了,但是这里我自己写了,是加深自己对决策树的理解,这里做个记录吧。

# -*- coding: utf-8 -*-
"""
Created on Tue May  5 10:02:36 2020

@author: Administrator
"""
# =============================================================================
# 构建决策树的整体思路如下:
#第一步:计算信息熵;
#第二步:根据特征来划分数据集,并计算每个特征对应的信息熵,选择信息熵最优的,重新划分数据集;
#第三步:重复第二步的内容,直到判断条件:①每个分支下的所有实例都具有相同的分类;②程序遍历完所有划分数据集的属性。
#       若满足条件①,那说明可以很好的分类,若不能满足条件①,即数据集已经处理了所有的属性,但类标签依然不是唯一,
#       通常采用表决的方法决定改叶子节点的分类,即输出分类数量最多的类别
# =============================================================================

from math import log
import operator
import pickle

def majorityCnt(classList):#表决的方法决定改叶子节点的分类
    classCount={}
    for vote in classList:
        if vote not in classCount.keys(): classCount[vote]=0
        classCount+=1
    sorteClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reversed=True)
    return sorteClassCount[0][0]

def calcShannonEnt(dataSet):#计算信息熵
    num=len(dataSet)
    labelCounts={}
    for featVec in dataSet:
        currentLabel=featVec[-1]   #取类别,即最后一列
        if currentLabel not in labelCounts.keys():   #判断其他的类别是否在类别列表中
            labelCounts[currentLabel]=0
        labelCounts[currentLabel] +=1  #统计每个类别的个数  
    shannonEnt=0.0
    for key in labelCounts:
        prob=float(labelCounts[key])/num
        shannonEnt -= prob*log(prob,2)
    return shannonEnt


def splitDataSet(dataSet,axis,value): #重新划分数据集
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis]==value:
            reducedFeatVec=featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet


def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1  #基本信息熵在按照最后一个特征先计算好了,因此这里减去1
    baseEntropy=calcShannonEnt(dataSet)#基本信息熵
    bestInfoGain=0.0
    bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]  #列出第i列的每一行特征
        uniqueVals=set(featList)   #set函数可以去除重复项
        newEntropy=0.0
        for value in uniqueVals:
            subDataSet=splitDataSet(dataSet,i,value)  #基于第i特征和value特征重新划分数据集
            prob=len(subDataSet)/float(len(dataSet))  #这个for循环是计算第i特征划分所得到的信息熵
            newEntropy +=prob*calcShannonEnt(subDataSet)
        infoGain=baseEntropy-newEntropy  #
        if(infoGain>bestInfoGain):#得到最好的信息增益
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature


def createTree(dataSet,labels):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):#类别完全相同则停止继续划分
        return classList[0]
    if len(dataSet)==1:     #遍历完所有特征时返回出现次数最多的
        return majorityCnt(classList)
    bestFeat=chooseBestFeatureToSplit(dataSet)
    bestFeatLabel=labels[bestFeat]
    myTree={bestFeatLabel:{}}
    del(labels[bestFeat])  #删除最优特征对于的标签
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals=set(featValues)
    for value in uniqueVals:
        subLabels=labels[:]
        myTree[bestFeatLabel][value] =createTree(splitDataSet(dataSet,bestFeat,value),subLabels)
    return myTree

def createDataSet():
    dataSet=[[1,1,'yes'],[1,1,'yes'],[1,0,'no'],[0,1,'no'],[0,1,'no']]
    labels=['no surfacing','flippers']
    return dataSet,labels

#预测使用
def classify(inputTree,featLabels,testVec):
    firstStr=list(inputTree.keys())[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] ==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:
                classLabel=secondDict[key]
    return classLabel
 
#决策树存储
def storeTree(inputTree,filename):
    fw=open(filename,'wb')
    pickle.dump(inputTree,fw)
    fw.close()
    
#还原树结构    
def grabTree(filename):
    fr=open(filename,'rb')
    return pickle.load(fr)

if __name__ == "__main__":
    dataSet,label=createDataSet()
    label1=label.copy()        #因为createTree()使用了del函数,因此这里将标签复制一份
    myTree=createTree(dataSet,label1)
    storeTree(myTree,'classfiyTree.txt')
    tree=grabTree('classfiyTree.txt')
    print(tree)
    print(classify(myTree,label,[1,1]))

猜你喜欢

转载自blog.csdn.net/qq_33047753/article/details/105951109