版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/potato012345/article/details/52734688
决策树
算法基本思想
决策树算法是一种基于树形结构的分类算法。通常一棵决策树包括一个根节点、若干内部节点和若干叶子节点。其中,根节点和内部节点代表分类所要依据的一系列属性,叶子节点则代表具体地类别。
决策树进行分类时,样例数据从根节点开始进行“属性测试”。该样例在节点对应的属性上的取值决定样例在树的哪一个分支上进行下一次属性测试,从上而下,直到叶子节点,此时样例所在叶子节点对应的类别就是对样例的判断结果。
决策树的构建算法
输入:训练集D和属性集A
输出:构造成功的决策树
函数:createTree(D,A)
createTree(D,A)
{
生成一个节点node
if D 中所有样例都属于同一个类别C then:
将 node 的类别置为C
return node
if A 为空 或 D 中所有样例在 A 上有相同的取值 then:
将 node 的类别置为 D 中占多数的类别
return node
从A中选择一个最优划分属性 a
for a 的每一个属性值 i do:
为node创建一个分支
D = {D | D(a)=i }
A = A - a
node的分支 = createTree(D,A)
return node
}
决策树以node 为根节点
最优划分属性的选取
显然,最优划分属性的选取是决策树构造过程中最关键的一步。那么我们从什么样的标准出发来选择最优的划分属性呢?
一般而言,我们希望数据集在经过节点的划分之后能够有更高的“纯度”,即样例尽可能地属于同一类类别。为了度量样例集的“信息纯度”,我们引入信息熵的概念。
信息熵定义如下:
进而,我们把经过节点的划分后的各个子集的信息熵之和相对于原集的变化程度作为选取划分属性的标准,称为信息增益。
信息增益定义如下:
能够使样例集的信息增益最大的属性即被选为最优划分属性。
信息增益只是众多选取指标中的一种,最早应用在ID3算法中。类似地标准还有信息增益率和基尼指数。
算法实现 Python 源代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Author : Zhangtd
# decision tree in MLinAction
from math import log
import operator
def calcShannonEnt(dataset): #计算数据集信息熵(ID3算法)
numEntries = len(dataset)
labelCounts = {}
for featVec in dataset: # 利用字典来统计各个类别出现的次数
currentLabel = featVec[-1]
if currentLabel not in labelCounts.keys(): # d.keys() 取出字典d的键组成列表
labelCounts[currentLabel] = 0
labelCounts[currentLabel] += 1
shannonEnt = 0.0
for key in labelCounts:
prob = float(labelCounts[key]) / numEntries
shannonEnt -= prob * log(prob, 2)
return shannonEnt
def createDataSet(): #读入数据集
fr = open('lenses.txt', 'r')
dataset = [line.strip().split('\t') for line in fr.readlines()]
# line.strip().split('\t') 将文件中的每一行的首尾部分的空白符删去,
# 然后以\t为标志将其划分存储,返回一个列表
labels = ['age', 'prescript', 'astigmatic', 'tearRate']
return dataset, labels
def splitDataSet(dataset, axis, val): #根据指定的属性及属性值对数据集进行删减,返回子集
retDataSet = []
for featVec in dataset:
if featVec[axis] == val:
reducedFeatVec = featVec[:axis]
reducedFeatVec.extend(featVec[axis + 1:]) #注意列表方法extend()与append()的的区别
retDataSet.append(reducedFeatVec)
return retDataSet
def chooseBestFeature(dataset): #选取使信息增益最大的属性值
numFeature = len(dataset[0]) - 1
baseEnt = calcShannonEnt(dataset)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeature):
featList = [example[i] for example in dataset] #列表推导式,用于复制列表
uniqueVal = set(featList) #set() 方法可以将对象转化为集合,自动删并重复值
newEnt = 0.0
for value in uniqueVal:
subDataset = splitDataSet(dataset, i, value)
prob = len(subDataset) / float(len(dataset))
newEnt += prob * calcShannonEnt(subDataset)
InfoGain = baseEnt - newEnt
if(InfoGain > bestInfoGain):
bestInfoGain = InfoGain
bestFeature = i
return bestFeature
def majorityCnt(classList): #出现第二种判断条件时用于确认多数类别的函数
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.iteritems(),
key=operator.itemgetter(1), reverse=True)
#operator模块的sorted函数,对字典进行排序
Tclass = sortedClassCount[0][0]
return Tclass
def createTree(dataset, lables): #构建决策树的主函数
classList = [example[-1] for example in dataset]
if(classList.count(classList[0]) == len(classList)):
return classList[0]
if(len(dataset[0]) == 1):
return majorityCnt(classList) #递归终止的两个判定条件
bestFeature = chooseBestFeature(dataset)
bestFeatureLable = lables[bestFeature]
myTree = {bestFeatureLable: {}}
del(lables[bestFeature])
featValues = [example[bestFeature] for example in dataset]
uniqueVal = set(featValues)
for val in uniqueVal:
subLables = lables[:]
subDataset = splitDataSet(dataset, bestFeature, val)
myTree[bestFeatureLable][val] = createTree(subDataset, subLables)
return myTree
def storeTree(filename, inputTree): #pickle模块将字典类型的决策树直接存储为文本
import pickle
fw = open(filename, 'w')
pickle.dump(inputTree, fw)
fw.close()
def loadTree(filename): #pickle模块读取文本中的字典
import pickle
fr = open(filename, 'r')
return pickle.load(fr)
def classify(inputTree, featLables, testVec): #读取决策树作为分类器用于数据分类
firstStr = inputTree.keys()[0]
secondDict = inputTree[firstStr]
featIndex = featLables.index(firstStr)
for key in secondDict.keys():
if testVec[featIndex] == key:
if type(secondDict[key]).__name__ == 'dict': #type().__name__ 取出对象的数据类型的名称
classLable = classify(secondDict[key], featLables, testVec)
else:
classLable = secondDict[key]
return classLable
#myData, lables = createDataSet()
#Lables = [Str for Str in lables]
#myTree = createTree(myData, Lables)
fName = 'myTree.txt'
#storeTree(fName , myTree)
myTree = loadTree(fName)
print myTree