def calcShannonEnt(dataSet):
    '''
    Compute the Shannon entropy of a labeled dataset.
    :param dataSet: labeled dataset, shape m*(n+1); m samples, n features,
                    last column is the class label
    :return: Shannon entropy of the label distribution
    '''
    total = len(dataSet)
    counts = {}  # occurrences of each class label
    for sample in dataSet:
        label = sample[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0
    for count in counts.values():
        p = count / total
        entropy -= p * np.log2(p)  # H = -sum(p * log2(p))
    return entropy
def splitDataSet(dataSet, feat, value):
    '''
    Select the samples whose feature `feat` equals `value`, removing that
    feature column from each returned sample.
    :param dataSet: labeled dataset, shape m*(n+1)
    :param feat: index of the feature used as the split point
    :param value: feature value to match
    :return: the matching sub-dataset without the split feature column
    '''
    # Slice concatenation drops column `feat` while copying the sample.
    return [sample[:feat] + sample[feat + 1:]
            for sample in dataSet
            if sample[feat] == value]
def chooseBestFeatureToSplit(dataSet):
    '''
    Try splitting on every value of every feature and pick the feature with
    the largest information gain (ID3 criterion).
    :param dataSet: labeled dataset, shape m*(n+1)
    :return: index of the best feature, or -1 if no split has positive gain
    '''
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestGain, bestFeat = 0, -1
    for feat in range(numFeatures):
        uniqueVals = {sample[feat] for sample in dataSet}
        # Weighted average entropy of the sub-datasets produced by this split.
        weightedEnt = 0
        for val in uniqueVals:
            subset = splitDataSet(dataSet, feat, val)
            weightedEnt += len(subset) / len(dataSet) * calcShannonEnt(subset)
        gain = baseEntropy - weightedEnt
        if gain > bestGain:
            bestGain, bestFeat = gain, feat
    return bestFeat
def classtarget(targetlist):
    '''
    Count the occurrences of each label and return the most frequent one.
    :param targetlist: list of class labels
    :return: the label with the highest count
    '''
    d = {}  # occurrences of each label
    for target in targetlist:
        if target not in d:
            d[target] = 0
        d[target] += 1
    # BUG FIX: the original `max(d, d.get)` passed the bound method as a second
    # positional argument, making max() compare the dict with a method object
    # (TypeError in Python 3). `key=` selects the label with the highest count.
    return max(d, key=d.get)
def creatTree(dataSet, feature_label):
    '''
    Recursively build the ID3 decision tree.
    :param dataSet: labeled dataset, shape m*(n+1)
    :param feature_label: human-readable names for the feature indices; lets
                          the resulting tree be read by humans
    :return: the decision tree as nested dicts, or a class label for a leaf
    '''
    targets = [sample[-1] for sample in dataSet]
    # Leaf case 1: every remaining sample has the same class.
    if len(set(targets)) == 1:
        return targets[0]
    # Leaf case 2: no features left to split on -> majority vote.
    if len(dataSet[0]) == 1:
        return classtarget(targets)
    best = chooseBestFeatureToSplit(dataSet)
    bestName = feature_label[best]
    # Children see the label list without the chosen feature, so their
    # indices stay aligned with the reduced dataset columns.
    childLabels = feature_label[:best] + feature_label[best + 1:]
    tree = {bestName: {}}  # root node
    for val in {sample[best] for sample in dataSet}:
        subset = splitDataSet(dataSet, best, val)
        tree[bestName][val] = creatTree(subset, childLabels[:])
    return tree
decisionNode = dict(boxstyle='sawtooth', fc='0.8') # decision (internal) node: sawtooth-edged box
leafNode = dict(boxstyle='round4', fc='0.8') # leaf node: rounded box
arrow_args = dict(arrowstyle='<-') # annotate() arrow style (drawn from child back toward parent)
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    '''
    Draw one annotated node with an arrow from parent to child.
    :param nodeTxt: text shown inside the node box
    :param centerPt: arrow end point (the node position)
    :param parentPt: arrow start point (the parent position)
    :param nodeType: box style dict (decisionNode or leafNode)
    :return: None
    '''
    createPlot.ax1.annotate(
        nodeTxt,
        xy=parentPt,
        xycoords='axes fraction',
        xytext=centerPt,
        textcoords='axes fraction',
        va='center',
        ha='center',
        bbox=nodeType,
        arrowprops=arrow_args,
    )
def getNumleafs(myTree):
    '''
    Recursively count the leaf nodes of a decision tree.
    :param myTree: decision tree (nested dicts; non-dict values are leaves)
    :return: number of leaf nodes
    '''
    root = list(myTree)[0]  # the single root key
    children = myTree[root]
    # A dict child is a subtree (recurse); anything else is one leaf.
    return sum(getNumleafs(child) if type(child) is dict else 1
               for child in children.values())
def getTreeDepth(myTree):
    '''
    Compute the depth of a decision tree (number of split levels).
    :param myTree: decision tree (nested dicts; non-dict values are leaves)
    :return: the depth of the deepest branch
    '''
    root = list(myTree)[0]  # the single root key
    children = myTree[root]
    # Each branch contributes 1 plus the depth of its subtree (if any).
    depths = [1 + getTreeDepth(child) if type(child) is dict else 1
              for child in children.values()]
    return max(depths, default=0)
def plotMidText(cntrPt, parentPt, txtString):
    '''
    Write txtString at the midpoint of the edge between parent and child.
    :param cntrPt: child node coordinates
    :param parentPt: parent node coordinates
    :param txtString: text to draw (typically the branch's feature value)
    :return: None
    '''
    # Midpoint of the segment, expressed as child + half the offset.
    midX = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    midY = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(midX, midY, txtString)
def plotTree(myTree, parentPt, nodeTxt):
    '''
    Recursively lay out and draw the tree on createPlot.ax1.

    Uses the function attributes set by createPlot: totalW/totalD (overall
    leaf count and depth) and xOff/yOff (the running drawing cursor in
    axes-fraction coordinates).
    :param myTree: decision tree (nested dicts)
    :param parentPt: coordinates of the parent node
    :param nodeTxt: branch text to write on the edge from the parent
    :return: None
    '''
    numLeafs = getNumleafs(myTree) # width of this subtree, in leaf slots
    depth = getTreeDepth(myTree)  # NOTE(review): computed but never used here
    firstStr = list(myTree.keys())[0]  # feature name at this subtree's root
    # Center this node above its leaves: advance from the cursor by half the
    # subtree width (in units of 1/totalW).
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode) # draw the decision node
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD # step one level down
    for key in secondDict.keys():
        if type(secondDict[key]) == dict:  # subtree: recurse with this node as parent
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            # Leaf: advance the x cursor one slot and draw it directly.
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD  # restore y before returning
def createPlot(myTree):
    '''
    Draw the decision tree with matplotlib and show the figure.
    :param myTree: decision tree (nested dicts)
    :return: None
    '''
    fig = plt.figure(1, facecolor='white') # create the canvas
    fig.clf() # clear any previous contents
    axprops = dict(xticks=[], yticks=[])  # hide axis ticks
    # Shared axes object that plotNode/plotMidText draw on.
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalW = float(getNumleafs(myTree)) # total width (leaf count)
    plotTree.totalD = float(getTreeDepth(myTree)) # total depth
    # Start the x cursor half a slot left of the axes so leaves center nicely.
    plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0
    plotTree(myTree, (0.5,1.0), '')
    plt.show()
def classify(myTree, featLabels, testVec):
    '''
    Classify one sample by walking the decision tree.
    :param myTree: trained decision tree (nested dicts)
    :param featLabels: list of all feature names, in testVec order
    :param testVec: feature values of the sample to classify
    :return: the predicted class label
    '''
    rootFeat = list(myTree)[0]  # feature tested at this node
    branches = myTree[rootFeat]
    # Follow the branch matching the sample's value for this feature.
    branch = branches[testVec[featLabels.index(rootFeat)]]
    if isinstance(branch, dict):
        # Internal node: keep descending.
        return classify(branch, featLabels, testVec)
    # Leaf node: its value is the class label.
    return branch
# 【机器学习】决策树(基于ID3算法)—— python3 实现方案
# Source: blog.csdn.net/zhenghaitian/article/details/81054767