决策树模型介绍

1、决策树介绍

决策树是一种树状结构,它的每一个叶节点对应着一个分类,非叶节点对应着在某个属性上的划分,根据样本在该属性不同取值将其划分为若干个子集。

决策树构造的核心问题:在每一步如何选择适当的属性对样本做拆分。

决策树处理过程:对分类问题,应从已知类标记的训练样本中学习并构造出决策树,自上而下,分开进行解决。

2、常用决策树算法

决策树算法 算法描述
ID3算法 核心:在决策树的各级节点上,使用信息增益方法作为属性的选择标准,来帮助确定生成每个节点时所应采用的合适属性。
C4.5算法 C4.5决策树生成算法相对于ID3算法的重要改进:使用信息增益率来选择节点属性。C4.5可客服ID3算法的不足:ID3算法只适用于离散的描述属性,而C4.5算法既能处理离散的描述属性,也可处理连续的描述属性
CART算法 CART决策树是一种非参数分类和回归方法,通过构建树、修剪树、评估树来构造一个二叉树。当终结点是连续变量时,该树为回归树;当终结点是分类变量,该树为分类树

3、ID3算法

ID3算法原理

ID3算法是基于信息熵来选择最佳测试属性

它选择当前样本集中具有最大信息增益值的属性作为测试属性

样本集的划分则依据测试属性的取值进行,测试属性有多少不同取值就将样本集划分为多少子样本集,同时决策树上相应于该样本集的节点长出新的叶子节点。

ID3算法根据信息论理论,采用划分样本集的不确定性作为衡量划分好坏的标准,用信息增益值度量不确定性,信息增益值越大,不确定性越小。

ID3算法在每个非叶节点选择信息增益最大的属性作为测试属性,这样可以得到当前情况下最纯的拆分,从而得到较小的决策树。

设S是s个数据样本的集合,假设类别属性具有m个不同值,Ci(i = 1,2,....,m),设si是类Ci中的样本数。对一个给定的样本,它总的信息熵为:

 其中,Pi是任意样本属于Ci的概率,一般可以用si/s估计

设一个属性A具有k个不同的值{a1,a2,...ak},利用属性A将集合S划分为个子集{S1,S2,...,Sk},其中Sj包含了集合S中属性A取aj值的样本。若选择属性A为测试属性,则这些子集就是从集合S的节点生长出来的新的叶节点

。设Sij是子集Sj中类别为Ci的样本数,则根据属性A划分样本的信息熵值为:

其中,

是子集Sj中类别为Ci的样本的概率。最后,用属性A划分样本集S后所得到的信息增益为:Gain(A)=I(s1.s2....,sm)-E(A)

E(A)越小,Gain(A)的值越大,说明选择测试属性A对于分类提供的信息越大,选择A后对分类的不确定成都越小,属性A的k个不同的值对应样本集S的k个子集或分支,通过递归调用上述过程,生成其他属性作为节点的子节点和分支来生成整个决策树。ID3决策树算法作为作为一个典型的决策树学习算法,其核心是在决策树的各级节点上用信息增益作为判断标准进行属性的选择,使得每个非叶节点上进行测试时,都能获得最大的类别分类增益,使分类后数据集的熵最小。这种处理方法使树的平均深度较小,从而有效提高分类效率。

ID3算法流程

(1)对当前样本集合,计算所有属性的信息增益;

(2)选择信息增益最大的属性作为测试属性,把测试属性取值相同的样本划分为同一个子样本集;

(3)若子样本集的类别属性只含有单个属性,则分支为叶子节点,判断其属性值并标上相应的符号,然后返回调用处;否则对子样本集递归调用本算法。

#-*- coding: utf-8 -*-
#使用ID3决策树算法预测销量高低
import pandas as pd

#参数初始化
inputfile = '../data/sales_data.xls'
data = pd.read_excel(inputfile, index_col = u'序号') #导入数据

#数据是类别标签,要将它转换为数据
#用1来表示“好”、“是”、“高”这三个属性,用-1来表示“坏”、“否”、“低”
data[data == u''] = 1
data[data == u''] = 1
data[data == u''] = 1
data[data != 1] = -1
x = data.iloc[:,:3].as_matrix().astype(int)
y = data.iloc[:,3].as_matrix().astype(int)

from sklearn.tree import DecisionTreeClassifier as DTC
dtc = DTC(criterion='entropy') #建立决策树模型,基于信息熵
dtc.fit(x, y) #训练模型

#导入相关函数,可视化决策树。
#导出的结果是一个dot文件,需要安装Graphviz才能将它转换为pdf或png等格式。
from sklearn.tree import export_graphviz
x = pd.DataFrame(x)
from sklearn.externals.six import StringIO
x = pd.DataFrame(x)
with open("tree.dot", 'w') as f:
  f = export_graphviz(dtc, feature_names = x.columns, out_file = f)

猜你喜欢

转载自www.cnblogs.com/Iceredtea/p/12056794.html