1、一般流程
1)收集数据:可以使用任何方法
2)准备数据:树构造算法只适用于标称型数据,因此数值型数据必须【离散化】
3)分析数据:可以使用任何方法,构造树完成之后,我们应该检查图形是否符合预期
4)训练算法:构造树的数据结构
5)测试算法:使用经验树计算错误率
6)使用算法:此步骤可以适用于任何监督学习算法,而使用决策树可以更好地理解数据内在含义
2、决策树的优缺点
- 优点:计算复杂度不高,输出结构易于理解,对中间的缺失不敏感,可以处理不相关特征数据
- 缺点:可能会产生过度匹配的问题
- 适用数据类型:数值型和标称型
3、ID3 算法实现
# python 3.6
# 20180425
# 加载包
import math
import tqdm # 显示循环进度
import operator
# 画图包
import matplotlib.pyplot as plt
plt.style.use('ggplot') #设置绘图style
from matplotlib.font_manager import FontProperties # 中文画图
myfont = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)
计算目标的熵
def Ent(dataset):
"""
计算目标的熵
:param dataset: 输入数据集 array
:return: 该数据集的目标的熵
"""
numEntries = len(dataset)
labelCounts = {}
# 将该数据集下所属分类做统计
for i in dataset:
nowlabel = i[-1]
labelCounts[nowlabel] = labelCounts.get(nowlabel,0) + 1
# 更加ID3 计算熵
sEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key])/numEntries
sEnt -= prob * math.log2(prob)
return sEnt
# 简单测试
def createdata():
dataset = [[1, 1, 'yes'],[1, 1, 'yes'],[1, 0, 'no'],
[0, 1, 'no'],[0, 1, 'no']]
labels = ['no surfacing', 'flippers']
return dataset, labels
dataset, labels = createdata()
Ent(dataset)
# 熵越高,则混合的数据越多
dataset[0][-1] = 'maybe'
Ent(dataset)
按给定特征划分数据集
def splitDataset(dataset, axis, value):
"""
按给定特征划分数据集
:param dataset: 数据集 array
:param axis: 指定第几列作为划分依据 int
:param value: 指定该列某值的划分
:return: 划分后的结果 array/list
"""
retdataset = []
for i in dataset:
if i[axis] == value:
tmp = i[:axis]
tmp.extend(i[axis+1:])
retdataset.append(tmp)
return retdataset
splitDataset(dataset, 1, 1)
选择最佳划分的特征
def BestSplit(dataset):
"""
选择最佳划分的特征
:param dataset: 数据集 array
:return: 信息增益最大的特征列 int
"""
numFeatures = len(dataset[0]) - 1
targetEnt = Ent(dataset) # 目标分类的熵
bestInfoGain = 0.0
bestFeature = -1
for i in tqdm.tqdm(range(numFeatures)): # 一列列循环
# 创建唯一分类的标签列表
featlist = [j[i] for j in dataset] # 取第i列的所有特征
uniqueVals = set(featlist)
# 求每种划分方式的信息熵
newEntropy = 0.0
for value in uniqueVals :
subdataset = splitDataset(dataset, i, value) # 提取第i列某个特征的集合
prob = len(subdataset) / float(len(dataset))
newEntropy += prob * Ent(subdataset)
# 选择信息增益最大的特征列
InfoGain = targetEnt - newEntropy
if (InfoGain > bestInfoGain):
bestInfoGain = InfoGain
bestFeature = i
return bestFeature
多数表决法
def majorCnt(classList):
"""
当最后分类结束时,叶还不纯,用多数类代表该叶的类
:param classList: 类的名称
:return: 出现次数最多的类名
"""
classcount = {}
for vote in classList:
classcount[vote] = classcount.get(vote,0) + 1
sortedclasscount = sorted(classcount.items(), key = operator.itemgetter(1),reverse = True)
return sortedclasscount[0][0]
创建决策树
def createTree(dataset, labels):
"""
递归生成树
:param dataset: 数据集 array
:param labels: 特征标签列表 list
:return: 树
"""
classList = [j[-1] for j in dataset]
# 递归结束判断——类别全同,即熵为0时不划分,返回该类标签
if classList.count(classList[0]) == len(classList):
return classList[0]
# 递归结束判断——遍历完所有特征,返回:出现次数最多的类别
if len(dataset[0]) == 1:
return majorCnt(classList)
bestFeat = BestSplit(dataset)
bestFeatlabel = labels[bestFeat]
del(labels[bestFeat]) # 删除最佳划列的特征标签
featvalues = [j[bestFeat] for j in dataset]
uniquevalues = set(featvalues)
myTree = {bestFeatlabel:{}}
for value in uniquevalues:
sublabels = labels[:]
myTree[bestFeatlabel][value] = createTree(splitDataset(dataset, bestFeat, value),sublabels)
return myTree
4、绘制决策树
定义文本框和箭头格式
# 判断节点框
decisionnode = dict(boxstyle = 'sawtooth', fc='0.8',ec = 'b', lw = 1)
# 叶节点框
leafnode = dict(boxstyle = 'round4', fc = '0.8')
# 箭头格式
arrow_args = dict(arrowstyle = "<|-",ec = 'black', lw = 1 )
绘制带箭头的注解
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
"""
绘制带箭头的注解
:param nodeTxt: 节点文字
:param centerPt: 文本中心
:param parentPt: 指向文本中心的点的位置
:param nodeType: 节点样式
"""
plt.annotate(nodeTxt, xy = parentPt, xycoords = 'axes fraction',
xytext = centerPt, textcoords = 'axes fraction',
ha='center',
va='center',
# size=15,
bbox= nodeType,
arrowprops= arrow_args, # 箭头的格式
fontproperties=myfont # 中文输入问题
)
消除量纲影响:
def autoNorm(dataset):
"""
将每个字段数据归一化
:param dataset: 训练集样本 array
:return: 归一化后的数据集 array;
"""
minv = dataset.min(0)
maxv = dataset.max(0)
ranges = maxv - minv
m = dataset.shape[0]
normDateset = dataset - tile(minv,(m,1))
normDateset = normDateset/tile(ranges,(m,1)) # 特征值相除
return normDateset, ranges, minv
normDateset, ranges, minv =autoNorm(datingDataMAT)
确定树的大小
# == 构造注解树
def getNumleafs(mytree):
# 计算决策树最后有多少叶
numLeaf = 0
firstStr = list(mytree.keys())[0]
secondDict = mytree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]) == dict:
# 如果子节点为字典类型,则为一个判断节点,需要用递归调用
numLeaf += getNumleafs(secondDict[key])
else:
numLeaf += 1
return numLeaf
def getTreeDepth(mytree):
# 计算决策树最后有多少节点
maxDepth = 0
firstStr = list(mytree.keys())[0]
secondDict = mytree[firstStr]
for key in secondDict.keys():
# 为判断节点就加一并继续遍历
if type(secondDict[key]) == dict:
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth :
maxDepth = thisDepth
return maxDepth
# 便于测试
list_tree = [{'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}},
{'no surfacing': {0: 'no', 1: {'flippers': {0: {'head':{0:'no', 1: 'yes'}},1:'no'}}}}]
getNumleafs(list_tree[1])
getTreeDepth(list_tree[1])
绘制决策树
def createPlot(inTree):
"""
创建决策树图示
:param inTree: 决策树 dict
:return: 决策树图示
"""
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])# 定义坐标轴为空
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
createPlot.ax1 = plt.subplot(111, frameon=False)
#全局变量--树的大小
plotTree.totalW = float(getNumleafs(inTree))
plotTree.totalD = float(getTreeDepth(inTree))
#例如绘制3个叶子结点,坐标应为1/3,2/3,3/3
#但这样会使整个图形偏右因此初始的,将x值向左移一点,即取-0.5 而非 1。
plotTree.xOff = -0.5/plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5,1.0), '')
plt.show()
def plotTree(myTree, parentPt, nodeTxt):
"""
绘制树的局部 辅助于createTree
:param myTree: 决策树 dict
:param parentPt: 指向文本中心的点
:param nodeTxt:
"""
# 当前树的大小
numLeafs = getNumleafs(myTree)
depth = getTreeDepth(myTree)
firstStr = list(myTree.keys())[0]
#cntrPt文本中心点 x:按比例确定; y:同按比例确定,从上到下的画
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)
# 画分支上的键
plotMidText(cntrPt, parentPt, nodeTxt)
#绘制带箭头的注解--判断节点
plotNode(firstStr, cntrPt, parentPt, decisionnode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD #从上往下画
for key in secondDict.keys():
# 如果是字典则是一个判断结点,递归调用
if type(secondDict[key]) == dict:
plotTree(secondDict[key],cntrPt,str(key))
# 绘制叶节点
else:
plotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW
# 绘制带箭头的注解
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafnode)
# 画分支上的键
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
def plotMidText(cntrPt, parentPt, txtString):
"""
在子父节点间填充文本信息
:param cntrPt: 文本中心的位置
:param parentPt: 指向文本中心的点的位置
:param txtString: 需要填写的文本(判断条件)
"""
xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]
yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
测试算法: 使用决策树执行分类
def classify(inputTree, featLabels, testVec):
"""
依据testVec 判断分类
:param inputTree: 用createTree生成的树
:param featLabels: 特征标签
:param testVec: 到达叶的路径
:return: 返回经过testVec值,所到达的叶节点的分类标签
"""
firstStr = list(inputTree.keys())[0]
secondDict = inputTree[firstStr]
featIndex = featLabels.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]) == dict:
classLabel = classify(secondDict[key], featLabels, testVec)
else:
classLabel = secondDict[key]
return classLabel
# 测试
mytree = list_tree[0]
Labels = ['no surfacing', 'flippers']
classify(mytree, Labels,[1,0])
使用算法:决策时的存储
def storeTree(inputTree, filename):
# 用pickle模块序列化对象:以二进制形式写入
import pickle
fw = open(filename,'wb')
pickle.dump(inputTree, fw)
fw.close()
def grabTree(filename):
# 通过pickle模块以二进制形式读取
import pickle
fr = open(filename,'rb')
out = pickle.load(fr)
fr.close()
return out
fil = r'C:\Users\dell\Desktop\1.txt'
storeTree(mytree, fil)
grabTree(fil)
5、使用决策树预测隐形眼镜类型
1)收集数据:提供的文本文件
2)准备数据:解析tab分隔的数据行
3)分析数据:快速检查数据,确保正确地解析数据的内容,使用createPlot()函数绘制最终的树形图
4)训练算法:使用createTree()函数
5)测试算法:编写测试函数验证决策树可以正确分类给定的数据实例
6)使用算法:存储树的数据结构,以便下次使用
# 读取数据
f = 'lenses.txt'
fr = open(f)
lenses = [lines.strip().split('\t') for lines in fr.readlines() ] # array
lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate'] # list
# 创建决策树
lensesTree = createTree(lenses, lensesLabels)
createPlot(lensesTree)
# 存储决策树
fil = r'C:\Users\dell\Desktop\lenses.txt'
storeTree(lensesTree, fil)
grabTree(fil)
还有不足,是ID3算法的不足及剪枝问题:
ID3算法的缺点:
1. 偏向于选择多值属性
2. 本身并未给出处理连续数据的方法
3. 算法不能处理带有缺失值的数据集
参考:《机器学习实战》
数据:链接:https://pan.baidu.com/s/17oKTrdJY_swA0Za8cvmtqw 密码:4rlk