Python decision trees: the ID3 algorithm

Practice once a week to improve yourself.

AI development in Python is in full swing, and the company has assigned a new task, so I am diving in with the fearless attitude of a beginner.

Environment:

    Python: 2.7.14

    Editor: Sublime Text 3

Installation steps and methods: https://my.oschina.net/wangzonghui/blog/1603104

Install the learning libraries:

    pip install numpy

    pip install scipy

    pip install matplotlib

    pip install scikit-learn

 

The decision tree is an entry-level supervised learning algorithm, and it is relatively easy to understand from a programming perspective. The main process is to build a tree from the features and labels of the training data, then evaluate test data against it to produce a prediction.

The mathematical theory behind decision trees rests on Shannon entropy (information entropy) and Gini impurity. I am not very interested in the underlying advanced mathematics and do not understand it deeply; if you are interested, you can look it up on Baidu.
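As a concrete illustration (a minimal sketch, independent of the full ID3 code below), Shannon entropy for a list of class labels is H = -Σ p·log2(p), where p is the frequency of each distinct label:

```python
from math import log

def entropy(labels):
    # H = -sum(p * log2(p)), where p is the frequency of each distinct label
    n = float(len(labels))
    counts = {}
    for lbl in labels:
        counts[lbl] = counts.get(lbl, 0) + 1
    return sum(-(c / n) * log(c / n, 2) for c in counts.values())

# A 50/50 split is maximally uncertain: entropy is 1 bit
print(entropy(['fight', 'run', 'fight', 'run']))  # -> 1.0
# A pure set has zero entropy
print(entropy(['fight', 'fight', 'fight']))       # -> 0.0
```

Lower entropy means a purer set, which is why ID3 prefers splits that reduce entropy the most.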

The mainstream algorithms are as follows:

ID3: the first generation. Flawed but foundational, and the best entry point. It splits on the feature with the largest information gain.

C4.5: an extension of ID3 that adds, among other things, support for continuous attributes and pruning. It splits on the feature with the largest information gain ratio.

C5.0: an optimized version of C4.5 with higher accuracy and efficiency. It is commercial software that is not available to the public.

CART: the classification and regression tree, which uses Gini impurity to decide splits. It differs from C4.5 in two ways: 1. a leaf node is not necessarily a specific class; for regression, it holds the regression function defined under that condition. 2. CART is a binary tree, not a multiway tree.

Except for CART, the implementations above are based on Shannon entropy.
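For comparison, Gini impurity (the criterion CART uses instead of entropy) is G = 1 - Σ p²; a minimal sketch:

```python
def gini(labels):
    # G = 1 - sum(p^2): 0.0 for a pure node, 0.5 for an even two-class split
    n = float(len(labels))
    counts = {}
    for lbl in labels:
        counts[lbl] = counts.get(lbl, 0) + 1
    return 1.0 - sum((c / n) ** 2 for c in counts.values())

print(gini(['fight', 'run', 'fight', 'run']))  # -> 0.5
print(gini(['fight', 'fight']))                # -> 0.0
```

Gini impurity behaves much like entropy but avoids the logarithm, which makes it slightly cheaper to compute.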

Here is the ID3 code:

#!/usr/bin/python
#coding:utf-8

from math import log
import operator

# Compute the Shannon entropy of the given data set
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)  # number of rows in the data set
    lableCounts={}
    for featVec in dataSet:
        currentLable=featVec[-1]  # the last column of each row is the class label
        if currentLable not in lableCounts.keys():
            lableCounts[currentLable]=0
        lableCounts[currentLable]+=1 # count how many times each label appears
    shannonEnt=0
    for key in lableCounts:
        prob=float(lableCounts[key])/numEntries  # probability of each label
        shannonEnt -=prob* log(prob,2)            # accumulate the Shannon entropy
    return shannonEnt

# Create a simple data set: weapon (0 rifle, 1 machine gun), bullets (0 few, 1 many), blood (0 low, 1 high); label: 'fight' or 'run'
def createDataSet():
    dataSet =[[1,1,0,'fight'],[1,0,1,'fight'],[1,0,1,'fight'],[1,0,1,'fight'],[0,0,1,'run'],[0,1,0,'fight'],[0,1,1,'run']]
    lables=['weapon','bullet','blood']
    return dataSet,lables

# Print the data set row by row
def printData(myData):
    for item in myData:
        print '%s' %(item)

# Split the data set on a given feature: dataSet (the data), axis (feature index), value (feature value to match)
def splitDataSet(dataSet,axis,value):
    retDataSet=[]
    for featVec in dataSet:
        if featVec[axis] == value:
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet

# Choose the best feature to split the data set on
# Preconditions: 1. the data is a list of rows (lists); 2. all rows have the same length;
# 3. the last column of each row is the class label
def chooseBestFeatureToSplit(dataSet):
    numFeatures=len(dataSet[0])-1  # the last column is the label, so feature count = row length - 1
    baseEntropy=calcShannonEnt(dataSet)
    bestInfoGain=0
    bestFeature=-1
    for i in range(numFeatures):
        featList=[example[i] for example in dataSet]
        uniqueVals =set(featList)
        newEntropy=0
        for value in uniqueVals:
            subDataSet =splitDataSet(dataSet,i,value)
            prob=len(subDataSet)/float(len(dataSet))
            newEntropy+=prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if(infoGain > bestInfoGain):
            bestInfoGain=infoGain
            bestFeature=i
    return bestFeature


# Build the decision tree recursively
def createTree(dataSet,lables):
    classList=[example[-1] for example in dataSet]
    if classList.count(classList[0])==len(classList):
        return classList[0]
    if len(dataSet[0] )== 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLable = lables[bestFeat]
    myTree={bestFeatLable:{}}
    del(lables[bestFeat])
    featValues=[example[bestFeat] for example in dataSet]
    uniqueVals = set(featValues)
    for value in uniqueVals:
        subLables = lables[:]
        myTree[bestFeatLable][value] = createTree(splitDataSet(dataSet,bestFeat,value),subLables)
    return myTree


# Pick the class name that occurs most often
def majorityCnt(classList):
    classCount={}
    for vote in classList:
        if vote not in classCount.keys():classCount[vote]=0
        classCount[vote] +=1
    sortedClassCount = sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

# Classify a sample with the decision tree
def classify(inputTree,featLabels,testVec):
    firstStr=inputTree.keys()[0]
    secondDict=inputTree[firstStr]
    featIndex=featLabels.index(firstStr)
    for key in secondDict.keys():
        if testVec[featIndex] ==key:
            if type(secondDict[key]).__name__=='dict':
                classLabel=classify(secondDict[key],featLabels,testVec)
            else:classLabel=secondDict[key]
    return classLabel

# Serialize the decision tree to disk
def storeTree(inputTree,filename):
    import pickle
    fw=open(filename,'wb')  # binary mode is required by pickle
    pickle.dump(inputTree,fw)
    fw.close()


# Load a decision tree from disk
def grabTree(filename):
    import pickle
    fr=open(filename,'rb')
    return pickle.load(fr)
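The store/load helpers can be verified with a quick round trip. This sketch uses pickle directly and is self-contained; the filename 'tree.pkl' is just an example, and the tree literal is the one this article's ID3 code produces:

```python
import pickle

tree = {'weapon': {0: {'blood': {0: 'fight', 1: 'run'}}, 1: 'fight'}}
with open('tree.pkl', 'wb') as fw:   # pickle requires binary mode
    pickle.dump(tree, fw)
with open('tree.pkl', 'rb') as fr:
    restored = pickle.load(fr)
print(restored == tree)  # -> True
```

Pickling the tree this way means an expensive tree only has to be built once and can be reloaded for later classification runs.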


###################### Run tests ######################
myDat,lables=createDataSet()
# printData(myDat)    # print the data set
#result=calcShannonEnt(myDat) # compute the Shannon entropy
#print(result)

# Split the data on a given feature value
# jiqiang=splitDataSet(myDat,0,1)
# printData(jiqiang)
# buqiang=splitDataSet(myDat,0,0)
# printData(buqiang)

# Compare the information gain of every feature and return the index of the best one to split on
#chooseBestFeatureToSplit(myDat) # the result shows weapon has the largest gain and is the best splitting feature

# Build the decision tree: {'weapon': {0: {'blood': {0: 'fight', 1: 'run'}}, 1: 'fight'}}
# This example contains 3 leaf nodes and 2 decision nodes
tree=createTree(myDat,lables)
print(tree)

# Predict a result with the decision tree
# dat,lab=createDataSet()
# result=classify(tree,lab,[1,0,0])
# print "the result is",result

Run the script directly to see the result. Learn a little every day, and enjoy it.

 
