《统计学习方法》系列(5)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u012294618/article/details/79835282

  本篇对应全书第五章,讲的是决策树。决策树(decision tree)是一种基本的分类与回归方法。决策树模型呈树形结构,在分类问题中,表示基于特征对实例进行分类的过程。决策树学习通常包括3个步骤:特征选择、决策树的生成和决策树的修剪。决策树学习常用的算法有ID3、C4.5和CART。


1、理论讲解

  ID3和C4.5生成的决策树只能用于分类问题,而CART生成的决策树既可用于分类问题也可用于回归问题,因此本文主要讨论用于分类的决策树,只在CART部分讲述用于回归的决策树。

1.1、基础知识

  决策树由结点(node)和有向边(directed edge)组成。结点有两种类型:内部结点(internal node)和叶结点(leaf node)。内部结点表示一个特征或属性,叶结点表示一个类。
  用决策树分类,从根结点开始,对实例的某一特征进行测试,根据测试结果,将实例分配到其子结点:这时,每一个子结点对应着该特征的一个取值。如此递归地对实例进行测试并分配,直至达到叶结点。最后将实例分到叶结点的类中。下面是一个决策树的示意图,矩形框和椭圆框分别表示内部结点和叶结点。

![在这里插入图片描述](https://img-blog.csdn.net/20180922211138224)
  决策树的学习过程就是决策树的生成过程,决策树的生成基本遵循以下流程:

输入:训练集D,特征集A;
输出:决策树T。
(1)若D中所有实例属于同一类 C k C_k ,则T为单结点树,并将类 C k C_k 作为该结点的类标记,返回T;
(2)若 A = A=\varnothing ,或D中样本在A上取值相同,则T为单结点树,并将D中实例数最大的类 C k C_k 作为该结点的类标记,返回T;
(3)从A中选取最优划分特征 A A^* ,对 A A^* 的每一可能值 a i a_i ,依 A = a i A^*=a_i 将D分割为若干非空子集 D i D_i ,将 D i D_i 中实例数最大的类作为标记,构建子结点,由结点及其子结点构成树T,返回T;
(4)对第i个子结点,以 D i D_i 为训练集,以 A { A } A-\{A^*\} 为特征集,递归地调用步(1)~步(3),得到子树 T i T_i ,返回 T i T_i

  上述流程中,有几个关键问题:1).如何选取最优划分特征;2).数据怎么分割;3).决策树停止分裂的条件。
  第一个问题,实际上就是“特征选择”,“特征选择”的目的在于选取对训练数据具有分类能力的特征,这样可以提高决策树学习的效率;如果利用一个特征进行分类的结果与随机分类的结果没有很大差别,则称这个特征是没有分类能力的,经验上扔掉这样的特征对决策树学习的精度影响不大;“特征选择”的关键是其准则,不同的算法会有所不同,这一点后面我们还会详细解释,不过它们的本质都在于“选择使得所有子结点数据最纯的特征”。
  第二个问题,最优划分特征 A A^* ,其数据属性分离散和连续两种,对于离散型数据,可以按照特征取值进行数据分割,每一个取值对应一个分裂点;对于连续型数据,不同的算法有不同的处理方式,后面我们再解释。
  第三个问题,决策树不能无限制分裂。一方面复杂度太高,易过拟合,泛化能力差;另一方面,分裂到后期容易受噪声数据影响。停止分裂的条件一般有这么几个:a).最小结点数:当结点数据量小于阈值时,停止分裂;b).熵或者基尼值小于阈值:熵和基尼值的大小表示数据的复杂程度,当熵或者基尼值过小时,表示数据的纯度比较大;c).决策树的深度达到指定条件;d).特征集为空。

  决策树生成的过程中,为了尽可能正确分类训练集,结点进行了多次分裂,有时会造成决策树分支过多而过拟合的情况。我们需要对已生成的树进行剪枝,将树变得更简单,从而使它具有更好的泛化能力。
  常见的剪枝策略有预剪枝(pre-pruning)和后剪枝(post-pruning)。
  预剪枝,即在完全正确分类训练集之前,较早地停止树的生长,前面所讲的决策树停止分裂的条件即可看作是预剪枝。一种更普遍的做法是,根据验证集,计算每次分裂对决策树性能的增益,如果这个增益值小于某个阈值则不进行分裂。预剪枝的优点在于算法简单,效率高;缺点在于其基于“贪心”的本质,过早地停止决策树的生长,可能带来欠拟合的风险。
  后剪枝,即在已生成的过拟合决策树上进行剪枝。主要的后剪枝策略有:REP(Reduced-Error Pruning,错误率降低剪枝)、PEP(Pessimistic Error Pruning,悲观剪枝)、MEP(Minimum Error Pruning,最小错误率剪枝)、CCP(Cost-Complexity Pruning,代价复杂度剪枝)。这里我们不赘述每种剪枝策略的细节,感兴趣的读者可阅读[11][12]

1.2、ID3、C4.5和CART

  1984年,Breiman等人提出CART;1986年,Quinlan提出ID3;1993年,Quinlan提出C4.5。ID3和C4.5是分类树,C4.5是对ID3的优化,CART是分类和回归树。
  由于关于此三种算法的介绍很多,因此本节会省略不必要的细节,只保留重要的结论性的内容讲述。

1.2.1、ID3

特性:

  最优特征选取准则:信息增益(information gain);
  数据分割:离散型数据,按取值进行分割;连续型数据,无法处理,只能将连续型数据离散化(如等距离数据划分)后再做处理;
  数据缺省值处理:无;
  剪枝:无。

缺点:

  1. 以信息增益作为特征选取准则,存在偏向于选取取值较多的特征的问题;
  2. 无法处理连续数据和数据缺省值;
  3. 未剪枝,可能过拟合。

1.2.2、C4.5

特性:

  最优特征选取准则:信息增益比(information gain ratio),这里需要注意的是,信息增益比准则对可取值数目较少的特征有所偏好,因此C4.5采用的是一个启发式的方法:先从候选特征中找出信息增益高于平均水平的特征,再从中选择信息增益比最高的;
  数据分割:离散型数据,同ID3;连续型数据,对于特征a的n个取值,将它们从小到大排序,求得n个取值的n-1个数据中点,对每一个数据中点 t i t_i ,根据特征取值大于或小于 t i t_i 对数据进行二分,计算特征a在每个数据中点 t i t_i 的信息增益,选取信息增益最大的 t i t_i 作为特征a的分裂点,再在不同特征之间根据信息增益比选择最优划分特征(与离散特征不同的是,如果当前结点为连续特征,则该特征后面还可参与子结点的产生选择过程);
  数据缺省值处理:有,具体请参考《机器学习》(周志华 著);
  剪枝:PEP。

缺点:

  1. C4.5生成的是多叉树,很多时候,在计算机中二叉树模型会比多叉树计算效率高;
  2. C4.5只能用于分类;
  3. C4.5使用了熵模型,其中有大量耗时的对数运算,如果是连续值还有大量的排序运算。

1.2.3、CART

  CART(Classification And Regression Tree,分类和回归树)既可用于分类也可用于回归。CART假设决策树是二叉树,递归地二分每个特征。

特性:

  最优特征选取准则:最小方差(回归树),基尼指数(分类树);
  数据分割:离散型数据,将其中一个离散值独立作为一个节点,其他的离散值生成另外一个节点,有多少个离散值就有多少种可供选择的划分方式(分裂后,若某一分支上的被选取特征取值多于1个,被选取特征依然可参与到后续结点的产生选择过程);连续型数据,同C4.5;
  数据缺省值处理:同C4.5;
  剪枝:CCP。

1.3、其它

  1. 从所有可能的决策树中选取最优决策树是NP完全问题,所以现实中决策树学习算法通常采用启发式方法,近似求解这一最优化问题,这样得到的决策树是次最优(sub-optimal)的;
  2. 决策树的生成对应于模型的局部选择,决策树的剪枝对应于模型的全局选择;
  3. 如果特征数量很多,也可以在决策树学习开始的时候,对特征进行选择,只留下对训练数据有足够分类能力的特征;
  4. 决策树的剪枝往往通过极小化决策树整体的损失函数来实现;

2、代码实现

2.1、sklearn实现

from __future__ import division
import numpy as np
import graphviz
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
from sklearn.model_selection import train_test_split


if __name__ == "__main__":
        iris = load_iris()
        X = iris.data
        y = iris.target

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)

        clf = DecisionTreeClassifier()
        clf.fit(X_train, y_train)
        score = clf.score(X_test, y_test)
        print "score : %s" % score

        dot_data = tree.export_graphviz(clf, out_file=None,
                         feature_names=iris.feature_names,
                         class_names=iris.target_names,
                         filled=True, rounded=True,
                         special_characters=True)
        graph = graphviz.Source(dot_data)
        graph.render("iris")

代码已上传至github:https://github.com/xiongzwfire/statistical-learning-method


#参考文献
[1] 《机器学习》(周志华 著)
[2] http://www.cnblogs.com/yonghao/p/5061873.html
[3] http://www.cnblogs.com/yonghao/p/5064996.html
[4] http://www.cnblogs.com/yonghao/p/5096358.html
[5] http://www.cnblogs.com/yonghao/p/5122703.html
[6] https://www.cnblogs.com/yonghao/p/5135386.html
[7] http://www.cnblogs.com/pinard/p/6050306.html
[8] http://www.cnblogs.com/pinard/p/6053344.html
[9] http://www.cnblogs.com/pinard/p/6056319.html
[10] https://www.zhihu.com/question/22697086
[11] https://zhuanlan.zhihu.com/p/30296061
[12] http://blog.sina.com.cn/s/blog_4e4dec6c0101fdz6.html
[13] https://blog.csdn.net/jiede1/article/details/76034328
[14] https://blog.csdn.net/baimafujinji/article/details/53269040
[15] https://www.zhihu.com/question/27205203?sort=created
[16] https://blog.csdn.net/mao_xiao_feng/article/details/52728164
[17] https://www.cnblogs.com/pinard/p/6140514.html
[18] https://zhuanlan.zhihu.com/p/30296061
[19] http://blog.sina.com.cn/s/blog_4e4dec6c0101fdz6.html
以上为本文的全部参考文献,对原作者表示感谢。

猜你喜欢

转载自blog.csdn.net/u012294618/article/details/79835282