决策树算法实现

# -*- coding: utf-8 -*-
"""
Created on Fri Feb 23 12:05:30 2018

@author: Administrator
"""
import math
#定义一个求熵的函数,输入为一维数组,输出为熵。
def getShang(X):
    temp={k:X.count(k) for k in set(X)}
    s=0
    for k,v in temp.items():
        s=s+(v/len(X))*(-math.log2(v/len(X)))
    return s
#定义一个求最优特征的函数,输入为一个二维数组,输出为最优特征所在的列
def getBestIndex(X):
    T=[]
    #遍历数组X的每一列,计算每个特征列对应的熵,放到数组Shang中
    for i in range(len(X[0])-1):
        #取出数组X的每一列(除了最后一列),定义临时数组temp1
        temp1=[x[i] for x in X]
        #遍历数组temp1,得到每个元素以及该元素出现的次数,以字典temp2来储存
        temp2={k:temp1.count(k) for k in set(temp1)}
        shang=0
        for k,v in temp2.items():
            temp3=[]
            for j in range(len(temp1)):
                if temp1[j]==k:
                    #Y的index与temp1的index相同
                    temp3.append([x[len(X[0])-1] for x in X][j])
            shang=shang+v/len(X)*getShang(temp3)
        T.append(shang)
    bestIndex=T.index(min(T))
    return bestIndex
#定义一个求决策树的函数,输入为二维数组加分类标签,输出为决策树
def fit(X,lable):
    tree={}
    #最好特征所在的位置为bestIndex
    bestIndex=getBestIndex(X)
    #最好的位置对应的特征列,记为bestList
    bestList=[x[bestIndex] for x in X]
    #对最好特征列去重,得到PureBestList
    pureBestList=list(set(bestList))
    branch={}
    tree[lable[bestIndex]]=branch
    #
    for v in pureBestList:
        #temp4用来承装利用v对X分组得到的结果
        temp4=[]
        for i in range(len(bestList)):
            if bestList[i]==v:
                temp4.append(X[i][-1])
        #如果能分完
        if getShang(temp4)==0:
            branch[v]=temp4[0]
        else:
            #对数组X划分子集subX,划分标准:在X中提取最好特征列中元素为v时所有的行
            #第一步:在X中提取最好特征列中元素为v时所有的行,得到sub1X
            sub1X=[x for x in X if x[bestIndex]==v]
            #第二步:删除最好特征列,得到subX
            subX=[[x[i] for i in range(len(x)) if i!=bestIndex] for x in sub1X] 
            #对Lable求子集,得到subLable
            subLable=[lable[i] for i in range(len(lable)) if i!=bestIndex]
            #迭代
            branch[v]=fit(subX,subLable)  
    return tree
X= [[1,1,1,'yes'],
    [1,0,0,'no'],
    [0,1,1,'no'],
    [0,0,1,'no'],
    [1,1,0,'yes'],
    [0,1,1,'no'],
    [0,0,0,'yes']]
lable =['态度','技能','学费']
print(fit(X,lable)) 

猜你喜欢

转载自blog.csdn.net/qq_41424519/article/details/81740283
今日推荐