自定义实现决策树——ID3算法

from  math import log
import operator
def createDataSet():
    #      房产 车  是否可能有偿还能力
    dataSet=[[1,1,'yes'],
             [1,1,'yes'],
             [1,0,'yes'],
             [0,1,'no'],
             [0,1,'yes'],
             [0,0,'no']]
    labels = ['是否有房产','是否有车']
    return dataSet,labels
dataSet,labels=createDataSet()
'''
#1.计算信息熵
#2.计算条件熵
#3.计算 信息增益=信息熵-条件熵
'''
#计算数据集的熵
def calcShannonEnt(dataSet):
    numEntries=len(dataSet)
    labelCounts={}
    #计算每个标签出现的次数
    for featVec in dataSet:
        currentLabel = featVec[-1]#得到标签
        #计算标签数
        if currentLabel not in labelCounts.keys():
            labelCounts[currentLabel] = 0
        labelCounts[currentLabel]+=1
    shannonEnt=0.0
    #计算熵
    for key in labelCounts:
        prob = float(labelCounts[key])/numEntries
        shannonEnt-=prob*log(prob,2)
    return shannonEnt
#测试以上的熵计算
dataSet,labels=createDataSet()
shan=calcShannonEnt(dataSet)
print(shan)
'''
熵值越大,则混合的数据也越多

dataSet[3][-1]='不确定'
shan=calcShannonEnt(dataSet)
print(shan)

另一个度量信息无序程度的方法是基尼不纯度,他指的是从一个数据集中随机选取子项,度量其被错误分类到其分组的概率
'''
def splitDataSet(dataSet,axis,value):
    #利用熵划分数据集:按照获取最大的信息增益的方法划分数据集
    #axis:第几特征
    #value:这个列的取值

    retDataSet = []
    for featVec in dataSet:
        if featVec[axis] == value:
            #抽取 说白了就是从列表中剔除axis(下面两行代码)
            reducedFeatVec = featVec[:axis]
            reducedFeatVec.extend(featVec[axis+1:])
            retDataSet.append(reducedFeatVec)
    return retDataSet
print(dataSet)
print("=================================================")
result=splitDataSet(dataSet,0,1)
print(result)
result=splitDataSet(dataSet,0,0)
print(result)
result=splitDataSet(dataSet,1,1)
print(result)
result=splitDataSet(dataSet,1,0)
print(result)
#遍历整个数据集,循环计算shannoent和splitDataSet()函数,找到最好的特征划分方式
def chooseBestFeatureToSplit(dataSet):
    #计算特征数量,最后一列为标签,不计入
    numFeatures=len(dataSet[0])-1 #2个特征
    #计算熵值
    baseEntropy = calcShannonEnt(dataSet)#整个数据集的信息熵
    #计算增益
    bestInfoGain = 0.0
    #最好特征的索引
    bestFeature = -1
    #循环所有的特征列
    for i in range(numFeatures):
        #取出每一列的特征值,这里是0,1
        featList = [example[i] for example in dataSet]
        #特征值去重复(要么有房,要么没房,两种)
        uniqueVals = set(featList)#set不能存重复数据
        newEntropy = 0.0
        for value in uniqueVals:#1,0(有房,没房)
            #循环计算每个列的每个特征取值,并划分子集出来
            subDataSet = splitDataSet(dataSet,i,value)
            #计算这个特征取值出现的概率
            prob = len(subDataSet)/float(len(dataSet))
            newEntropy +=prob*calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if (infoGain > bestInfoGain):
            bestInfoGain=infoGain
            bestFeature =i
    return bestFeature
#测试:求出数据集中可用于分割数据集的最好特征
index=chooseBestFeatureToSplit(dataSet)
print(index)
#输出结果为0
'''
分析:得到原始的数据集,然后基于最好的属性划分数据集,由于特征值可能多余两个,因此可能存在大于两个分支的数据集
划分后,数据将被向下传递到数分支的下一个节点再次划分,以上采取递归形式进行
递归结束的条件:程序遍历完所有划分数据集的属性,或者每个分之下的所有实例都具有相同的分类

以上结束条件的算法还有其他,如:c4.5,Cart

'''
def majorityCnt(classList):
    '''
    classList:分类的名称
    '''
    classCount={}#存每个分类出现的频率{‘yes’:3,'no':2}
    for vote in classList:
        classCount[vote]=classCount.get(vote,0)+1
    #排序后输出出现次数最多的分类名称,operator.itemgetter(1)表示后面的数字部分
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]
        
#取 最多的那一个标签 
dataSet,labels=createDataSet()#再生成一次数据集
print(dataSet,labels)
def createTree(dataSet,labels,min_samples_split=None,max_features=None):
    '''
    创建数:递归
    1.类别相同时停止划分
    2.遍历完所有特征时,返回出现次数最多的,
    3.得到列表包含的所有属性值
    '''
    classList = [example[-1] for example in dataSet] #取类别{'yes','yes','yes','no','no'}
    #如果类别完全相同则停止继续划分
    #TODO:这里可以指定划分结束的条件1:min_sample_split 表示所有的类别是否纯净
    
    if classList.count(classList[0])==len(classList):
        #print('classList[0]',classList[0])
        return classList[0]
    #如果只有一个特征列,也停止划分,即对应max_features
    if max_features==None:
        max_features=1
    if len(dataSet[0])<=max_features:
        return majorityCnt(classList)
    #选取最好的特征的索引
    bestFeat = chooseBestFeatureToSplit(dataSet)
    #print('bestFeat:',bestFeat)
    #根据索引取出对应的,最好的分类的特征名字
    bestFeatLabel=labels[bestFeat]
    #print('bestFeatLabel:',bestFeatLabel)
    myTree = {bestFeatLabel:{}}#构建树
    del(labels[bestFeat])#从标签中删除这个特征名(是否有房产被最先被删除)
    featValues = [example[bestFeat] for example in dataSet]
    #print(featValues)
    uniqueVals=set(featValues)#唯一的特征值
    #print(uniqueVals)
    
    for value in uniqueVals:
        subLabels = labels[:]
        #print('-------------------------------')
        #print('subLabels',subLabels)
        #print('=======================()()()()()()()()()()()()()()())()()()()(())===============')
        myTree[bestFeatLabel][value] = createTree( splitDataSet(dataSet,bestFeat,value),subLabels)
    
    return myTree
tree = createTree(dataSet,labels)
tree
#{'是否有房产': {0: {'是否有车': {0: 'no', 1: 'no'}}, 1: 'yes'}}
def classify(inputTree,featLabels,testVec):
    '''
    inputTree:决策树
    featLabels:标签向量
    testVec:测试数据
    '''
    #在python2中,dict.keys()返回一个列表,python3中dict.keys()返回一个dict_keys对象
    firstStr=list(inputTree.keys())[0]
    print('firstStr:',firstStr)
    secondDict = inputTree[firstStr]
    print('secondDict:',secondDict)
    featIndex = featLabels.index(firstStr)#将标签转换为索引 0
    print('featIndex:',featIndex)
    key = testVec[featIndex]#key=1
    print('key:',key)
    valueOfFeat = secondDict[key]#valueOfFeat: yes
    print('valueOfFeat:',valueOfFeat)
    if isinstance(valueOfFeat,dict):
        classLabel = classify(valueOfFeat,featLabels,testVec)
        print('valueOfFeat:',valueOfFeat)
        print('featLabels:',featLabels)
        print('testVec:',testVec)
    else:
        classLabel = valueOfFeat
    return classLabel
#测试:
dataSet,labels=createDataSet()
print(labels)
tree=createTree(dataSet,labels)
print(tree)
print('==================================')
dataSet,labels=createDataSet()
classify(tree,labels,[0,1])

隐形眼镜的案例:lenses.txt具体内容如下,不会上传文件大家多担待.......

young    myope    no    reduced    no lenses
young    myope    no    normal    soft
young    myope    yes    reduced    no lenses
young    myope    yes    normal    hard
young    hyper    no    reduced    no lenses
young    hyper    no    normal    soft
young    hyper    yes    reduced    no lenses
young    hyper    yes    normal    hard
pre    myope    no    reduced    no lenses
pre    myope    no    normal    soft
pre    myope    yes    reduced    no lenses
pre    myope    yes    normal    hard
pre    hyper    no    reduced    no lenses
pre    hyper    no    normal    soft
pre    hyper    yes    reduced    no lenses
pre    hyper    yes    normal    no lenses
presbyopic    myope    no    reduced    no lenses
presbyopic    myope    no    normal    no lenses
presbyopic    myope    yes    reduced    no lenses
presbyopic    myope    yes    normal    hard
presbyopic    hyper    no    reduced    no lenses
presbyopic    hyper    no    normal    soft
presbyopic    hyper    yes    reduced    no lenses
presbyopic    hyper    yes    normal    no lenses

#prescript:药方 astigmatic:散光 presbyopic 远视眼 myope近视
#使用以上代码实现预测隐形眼镜类型 
fr = open('dataset/lenses.txt')
lenses=[line.strip().split('\t') for line in fr.readlines()]
lensesLabel=['age','prescript','astigmatic','tearRate']
tree=createTree(lenses,lensesLabel)
print(tree)
lensesLabel=['age','prescript','astigmatic','tearRate']
classify(tree,lensesLabel,['presbyopic','hyper','yes','normal'])

猜你喜欢

转载自blog.csdn.net/WJWFighting/article/details/81083153