本文已参与「新人创作礼」活动,一起开启掘金创作之路。
引言
决策树是机器学习中一个常见的算法模型,本文将从一个简单的离散区间的分类问题分析决策树的建立过程,并使用python手动实现简易的决策树
背景
本示例以面试为背景,选取了三个指标:是否为985,学历,编程语言来评判是否能被录取。
决策树建立过程
1)将所有指标添加进指标列表
2)计算指标列表中每个指标的熵,并按照最小熵的指标进行划分
3)如果能够成功划分则直接得出结果,否则将该指标从指标列表移除后执行步骤2)并将结果添加到决策树中
注:在本文中用字典来表示树的结构(多叉树)
依赖包
import numpy as np
from collections import Counter
from math import log2
复制代码
计算损失
信息熵
def entropy(y_label):
counter = Counter(y_label)
ent = 0.0
for num in counter.values():
p = num / len(y_label)
ent += -p * log2(p)
return ent
复制代码
这里除了使用信息熵之外还可以使用基尼系数代替
基尼系数
def geni(y_label):
counter = Counter(y_label)
g = 1
for num in counter.values():
p = num / len(y_label)
g -= p * p
return g
复制代码
定义决策树
class DecisionTree:
def __init__(self):
self.tree = {}
#训练决策树
def fit(self,X,y):
cols = list(range(X.shape[1]))
#对X得每一列数据,计算分割后得信息熵
self.tree = self._genTree(cols, X, y)
#递归生成决策树
def _genTree(self, cols, X, y):
# 计算最小信息熵得特征
imin = cols[0] # 最下熵得列
emin = 100 # 最小熵值
for i in cols:
coli = X[:,i]#拿到第i个特征数据
enti = sum([entropy(y[coli==d]) for d in set(coli)]) # (也可以使用基尼系数计算,下同)
if enti < emin:
imin = i
emin = enti
#根据最小熵特征有几个值,就生成几个新的子树分支
newtree={}
mincol = X[:,imin]
cols.remove(imin)
#针对这个特征得每个值,进一步划分树
for d in set(mincol):
entd = entropy(y[mincol==d]) # 计算信息熵
if entd <1e-10:#已经完全分开
newtree[d] = y[mincol==d][0]
else:#还需要进一步细分
newtree[d] = self._genTree(cols.copy(), X[mincol==d, :], y[mincol==d])
return {imin: newtree}#将列号作为索引,返回新生成的树
#预测新样本
def predict(self, X):
X = X.tolist()
y = [None for i in range(len(X))]
for i in range(len(X)):
predictDict = self.tree
while predictDict != 'Yes' and predictDict != 'No':
col = list(predictDict.keys())[0]
predictDict = predictDict[col]
predictDict = predictDict[X[i][col]]
else:
y[i] = predictDict
return y
复制代码
测试
X=np.array([['Yes985','本科','C++'],
['Yes985','本科','Java'],
['No985' ,'硕士','Java'],
['No985' ,'硕士','C++'],
['Yes985','本科','Java'],
['No985' ,'硕士','C++'],
['Yes985','硕士','Java'],
['Yes985','博士','C++'],
['No985' ,'博士','Java'],
['No985' ,'本科','Java']])
y=np.array(['No','Yes','Yes','No','Yes','No','Yes','Yes','Yes','No'])
dt = DecisionTree()
dt.fit(X, y)
print(dt.tree)
print(dt.predict(X))
复制代码
可以看到在训练数据上测试的结果是非常好的,大家可以自己制造一些数据去测试一下。
代码解释
entropy函数:用于计算信息熵,其中Counter用来统计不同类别的数量,每种类别的数量除以总数得到概率,再通过概率求的信息熵。 DecisonTree.fit:用于训练模型,在这个函数中建立树的根结点。 DecisonTree._genTree:建立树,建立过程如上文所述 DecisonTree.predict:进行预测,在决策树中只有最后的预测结果是叶子节点。 预测过程: 1)选取指标 2)根据选取的指标,并且按照指标子节点与数据对应的值建立选取新的子结点 3)当子节点是指标时,使用该指标重复过程2),当子节点是预测结果时预测结束