机器学习之决策树原理和sklearn实践

1. 场景描述

时间:早上八点,地点:婚介所

‘闺女,我有给你找了个合适的对象,今天要不要见一面?’

‘多大?’ ‘26岁’

‘长的帅吗?’ ‘还可以,不算太帅’

‘工资高吗?’ ‘略高于平均水平’

‘会写代码吗?’ ‘人家是程序员,代码写的棒着呢!’

‘好,把他的联系方式发过来吧,我抽空见一面’

上面的场景描述摘抄自 <百面机器学习>,是一个典型的决策树分类问题,通过年龄、长相、工资、是否会编程等特征属性对介绍对象进行是否约会进行分类

决策树是一种自上而下,对样本数据进行树形分类的过程,由结点和有向边组成,每个结点(叶结点除外)便是一个特征或属性,叶结点表示类别。从顶部根结点开始,所有样本聚在仪器,经过根结点的划分,样本被分到不同的子结点中。再根据子结点的特征进一步划分,直至样本都被分到某一类别(叶子结点)中

2. 决策树原理

决策树作为最基础、最常见的有监督学习模型,常被用于分类问题和回归问题,将决策树应用集成思想可以得到随机森林、梯度提升决策树等模型。其主要优点是模型具有可读性,分类速度快。决策树的学习通常包括三个步骤:特征选择、决策树的生成和决策树的修剪,下面对特征选择算法进行描述和区别

2.1 ID3---最大信息增益

在信息论与概率统计中,熵(entropy)是表示随机变量不确定性的度量,设X是一个取有限个值的随机变量,其概率分布为:\[P(X=X_i)=P_i (i = 1,2,...,n)\],则随机变量X的熵定义为:\[H(X) = -\sum_{i=1}^np_i\log{p_i}\]表达式中的对数以2为底或以e为底,这时熵的单位分别称作bit或nat,从表达式可以看出X的熵与X的取值无关,所以X的熵也记作\(H(p)\),即\[H(p) = -\sum_{x=1}^np_i\log{p_i}\]熵取值越大,随机变量的不确定性越大

条件熵:

条件熵H(Y|X)表示在已知随机变量X的条件下,随机变量Y的不确定性,随机变量X给定的条件下随机变量Y的条件熵定义为X给定条件下Y的条件概率分布的熵对X的数学期望\[H(Y|X) = \sum_{i=1}^nP(X=X_i)H(Y|X=X_i)\]

信息增益:\[g(D,A) = H(D) - H(D|A)\]

import pandas as pd
data = {
        '年龄':['老','年轻','年轻','年轻','年轻'],
        '长相':['帅','一般','丑','一般','一般'],
        '工资':['高','中等','高','高','低'],
        '写代码':['不会','会','不会','会','不会'],
        '类别':['不见','见','不见','见','不见']}
frame = pd.DataFrame(data,index=['小A','小B','小C','小D','小L'])
print(frame)
    年龄  长相  工资 写代码  类别
小A   老   帅   高  不会  不见
小B  年轻  一般  中等   会   见
小C  年轻   丑   高  不会  不见
小D  年轻  一般   高   会   见
小L  年轻  一般   低  不会  不见
import math
print(math.log(3/5))
print('H(D):',-3/5 *math.log(3/5,2) - 2/5*math.log(2/5,2))
print('H(D|年龄)',1/5*math.log(1,2)+4/5*(-1/2*math.log(1/2,2)-1/2*math.log(1/2,2)))
print('以同样的方法计算H(D|长相),H(D|工资),H(D|写代码)')
print('H(D|长相)',0.551)
print('H(D|工资)',0.551)
print('H(D|写代码)',0)
-0.5108256237659907
H(D): 0.9709505944546686
H(D|年龄) 0.8
以同样的方法计算H(D|长相),H(D|工资),H(D|写代码)
H(D|长相) 0.551
H(D|工资) 0.551
H(D|写代码) 0

计算信息增益:g(D,写代码)=0.971最大,可以先按照写代码来拆分决策树

2.2 C4.5---最大信息增益比

以信息增益作为划分训练数据集的特征,存在偏向于选择取值较多的问题,使用信息增益比可以对对着问题进行校正,这是特征选择的另一标准
信息增益比定义为其信息增益g(D,A)与训练数据集D关于特征A的值的熵\(H_A(D)\)之比:\[g_R(D,A) = \frac{g(D,A)}{H_A(D)}\]

\[H_A(D) = -\sum_{i=1}^n\frac{|D_i|}{|D|}\log\frac{|D_i|}{|D|}\]

拿上面ID3的例子说明:
\[H_年龄(D) = -1/5*math.log(1/5,2)-4/5*math.log(4/5,2)\]

\[g_R(D,年龄) = H_{年龄}(D)/g(D,年龄) = 0.171/0.722 = 0.236 \]

2.3 CART----最大基尼指数(Gini)

Gini描述的是数据的纯度,与信息熵含义类似,分类问题中,假设有K个类,样本点数据第k类的概率为\(P_k\),则概率分布的基尼指数定义为:
\[Gini(p) = 1- \sum_{k=1}^Kp_k(1-p_k) = 1 - \sum_{k=1}^Kp_{k}^2\]
对于二分类问题,弱样本点属于第1个类的概率是p,则概率分布的基尼指数为\[Gini(p) = 2p(1-p)\],对于给定的样本几何D,其基尼指数为\[Gini(D) = 1 - \sum_{k=1}^K[\frac{|C_k|}{|D|}]^2\]注意这里\(C_k\)是D种属于第k类的样本子集,K是类的个数,如果样本几个D根据特征A是否取某一可能指a被分割成D1和D2两部分,则在特征A的条件下,集合D的基尼指数定义为\[Gini(D,A) = \frac{|D_1|}{|D|}Gini(D_1)+\frac{|D_2|}{|D|}Gini(D_2)\]
\[Gini(D|年龄=老)=1/5*(1-1)+4/5*[1-(1/2*1/2+1/2*1/2)] = 0.4\]

CART在每一次迭代种选择基尼指数最小的特征及其对应的切分点进行分类

2.4 ID3、C4.5与Gini的区别

2.4.1 从样本类型角度

从样本类型角度,ID3只能处理离散型变量,而C4.5和CART都处理连续性变量,C4.5处理连续性变量时,通过对数据排序之后找到类别不同的分割线作为切割点,根据切分点把连续型数学转换为bool型,从而将连续型变量转换多个取值区间的离散型变量。而对于CART,由于其构建时每次都会对特征进行二值划分,因此可以很好地适合连续性变量。

2.4.2 从应用角度

ID3和C4.5只适用于分类任务,而CART既可以用于分类也可以用于回归

2.4.3 从实现细节、优化等角度

ID3对样本特征缺失值比较敏感,而C4.5和CART可以对缺失值进行不同方式的处理,ID3和C4.5可以在每个结点熵产生出多叉分支,且每个特征在层级之间不会复用,而CART每个结点只会产生两个分支,因此会形成一颗二叉树,且每个特征可以被重复使用;ID3和C4.5通过剪枝来权衡树的准确性和泛化能力,而CART直接利用全部数据发现所有可能的树结构进行对比。

3. 决策树的剪枝

3.1 为什么要进行剪枝?

对决策树进行剪枝是为了防止过拟合

根据决策树生成算法通过训练数据集生成了复杂的决策树,导致对于测试数据集出现了过拟合现象,为了解决过拟合,就必须考虑决策树的复杂度,对决策树进行剪枝,剪掉一些枝叶,提升模型的泛化能力

决策树的剪枝通常由两种方法,预剪枝和后剪枝

3.2 预剪枝

预剪枝的核心思想是在树中结点进行扩展之前,先计算当前的划分是否能带来模型泛化能力的提升,如果不能,则不再继续生长子树。此时可能存在不同类别的样本同时存于结点中,按照多数投票的原则判断该结点所属类别。预剪枝对于何时停止决策树的生长有以下几种方法

  • (1)当树达到一定深度的时候,停止树的生长
  • (2)当叶结点数到达某个阈值的时候,停止树的生长
  • (3)当到达结点的样本数量少于某个阈值的时候,停止树的生长
  • (4)计算每次分裂对测试集的准确度提升,当小于某个阈值的时候,不再继续扩展

预剪枝思想直接,算法简单,效率高特点,适合解决大规模问题。但如何准确地估计何时停止树的生长,针对不同问题会有很大差别,需要一定的经验判断。且预剪枝存在一定的局限性,有欠拟合的风险

3.3 后剪枝

后剪枝的核心思想是让算法生成一颗完全生长的决策树,然后从底层向上计算是否剪枝。剪枝过程将子树删除,用一个叶结点代替,该结点的类别同样按照多数投票原则进行判断。同样地,后剪枝叶可以通过在测试集上的准确率进行判断,如果剪枝过后的准确率有所提升,则进行剪枝,后剪枝方法通常可以得到泛化能力更强的决策树,但时间开销更大

损失函数

\[C_a(T) = \sum_{t=1}^{|T|}N_tH_t(T) + a|T|\]

\(其中|T|为叶结点个数,N_t为结点t的样本个数,H_t(T)为结点t的信息熵,a|T|为惩罚项,a>=0\)

\[C_a(T) = \sum_{t=1}^{|T|}N_tH_t(T) + a|T| = -\sum_{t=1}^{|T|}\sum_{k=1}^KN_{tk}\log \frac{N_{tk}}{N_t} + a|T|\]

注意:上面的公式中是\(N_{tk}\log \frac{N_{tk}}{N_t}\),而不是\(\frac{N_{tk}}{N_t} \log \frac{N_{tk}}{N_t}\)

令:\[C_a(T) = C(T) + a|T|\]

\(C(T)\)表示模型对训练数据的预测误差,即模型与训练数据的拟合程度,|T|表示模型复杂度,参数a>=0控制两者的影响力,较大的a促使选择较简单的模型,较小的a促使选择复杂的模型,a=0意味着只考虑模型与训练数据的拟合程度,不考虑模型的复杂度

4. 使用sklearn库为卫星数据集训练并微调一个决策树

4.1 需求

  • a.使用make_moons(n_samples=10000,noise=0.4)生成一个卫星数据集
  • b.使用train_test_split()拆分训练集和测试集
  • c.使用交叉验证的网格搜索为DecisionTreeClassifier找到合适的超参数,提示:尝试max_leaf_nodes的多种值
  • d.使用超参数对整个训练集进行训练,并测量模型测试集上的性能

代码实现

from sklearn.datasets import make_moons
import numpy as np
import pandas as pd
dataset = make_moons(n_samples=10000,noise=0.4)
print(type(dataset))
print(dataset)
<class 'tuple'>
(array([[ 0.24834453, -0.11160162],
       [-0.34658051, -0.43774172],
       [-0.25009951, -0.80638312],
       ...,
       [ 2.3278198 ,  0.39007769],
       [-0.77964208,  0.68470383],
       [ 0.14500963,  1.35272533]]), array([1, 1, 1, ..., 1, 0, 0], dtype=int64))
dataset_array = np.array(dataset[0])
label_array = np.array(dataset[1])
print(dataset_array.shape,label_array.shape)
(10000, 2) (10000,)
# 拆分数据集
from sklearn.model_selection import train_test_split
x_train,x_test = train_test_split(dataset_array,test_size=0.2,random_state=42)
print(x_train.shape,x_test.shape)
y_train,y_test = train_test_split(label_array,test_size=0.2,random_state=42)
print(y_train.shape,y_test.shape)
(8000, 2) (2000, 2)
(8000,) (2000,)
# 使用交叉验证的网格搜索为DecisionTreeClassifier找到合适的超参数
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV

decisionTree = DecisionTreeClassifier(criterion='gini')
param_grid = {'max_leaf_nodes': [i for i in range(2,10)]}
gridSearchCV = GridSearchCV(decisionTree,param_grid=param_grid,cv=3,verbose=2)
gridSearchCV.fit(x_train,y_train)
Fitting 3 folds for each of 8 candidates, totalling 24 fits
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=2 ................................................
[CV] ................................. max_leaf_nodes=2, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=3 ................................................
[CV] ................................. max_leaf_nodes=3, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=4 ................................................
[CV] ................................. max_leaf_nodes=4, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=5 ................................................
[CV] ................................. max_leaf_nodes=5, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=6 ................................................
[CV] ................................. max_leaf_nodes=6, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=7 ................................................
[CV] ................................. max_leaf_nodes=7, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=8 ................................................
[CV] ................................. max_leaf_nodes=8, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s
[CV] max_leaf_nodes=9 ................................................
[CV] ................................. max_leaf_nodes=9, total=   0.0s


[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    0.0s remaining:    0.0s
[Parallel(n_jobs=1)]: Done  24 out of  24 | elapsed:    0.0s finished

GridSearchCV(cv=3, error_score='raise-deprecating',
       estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best'),
       fit_params=None, iid='warn', n_jobs=None,
       param_grid={'max_leaf_nodes': [2, 3, 4, 5, 6, 7, 8, 9]},
       pre_dispatch='2*n_jobs', refit=True, return_train_score='warn',
       scoring=None, verbose=2)
print(gridSearchCV.best_params_)
decision_tree = gridSearchCV.best_estimator_
{'max_leaf_nodes': 4}
# 使用测试集对模型进行评估
from sklearn.metrics import accuracy_score
y_prab = gridSearchCV.predict(x_test)
print('accuracy_score:',accuracy_score(y_test,y_prab))
accuracy_score: 0.8455
# 可视化模型
from sklearn.tree import export_graphviz

export_graphviz(decision_tree,
               out_file='./tree.dot',
               rounded = True,
               filled = True)

生成tree.dot文件,然后使用dot命令\[dot -Tpng tree.dot -o decisontree_moons.png\]

5. 附录

5.1 sklearn.tree.DecisionTreeClassifier类说明

5.1.1 DecsisionTreeClassifier类参数说明

  • criterion: 特征选择方式,string,('gini' or 'entropy'),default='gini'
  • splitter: 每个结点的拆分策略,('best' or 'random'),string,default='best'
  • max_depth: int,default=None
  • min_samples_split: int,float,default=2,分割前所需的最小样本数
  • min_samples_leaf:
  • min_weight_fraction_leaf:
  • max_features:
  • random_state:
  • max_leaf_nodes:
  • min_impurity_decrease:
  • min_impurity_split:
  • class_weight:
  • presort: bool,default=False,对于小型数据集(几千个以内)设置presort=True通过对数据预处理来加快训练,但对于较大训练集而言,可能会减慢训练速度

5.1.2 DecisionTreeClassifier属性说明

  • classes_:
  • feature_importances_:
  • max_features_:
  • n_classes_:
  • n_features_:
  • n_outputs_:
  • tree_:

5.2 GridSearchCV类说明

5.2.1 GridSearchCV参数说明

  • estimator: 估算器,继承于BaseEstimator
  • param_grid: dict,键为参数名,值为该参数需要测试值选项
  • scoring: default=None
  • fit_params:
  • n_jobs: 设置要并行运行的作业数,取值为None或1,None表示1 job,1表示all processors,default=None
  • cv: 交叉验证的策略数,None或integer,None表示默认3-fold, integer指定“(分层)KFold”中的折叠数
  • verbose: 输出日志类型

5.2.2 GridSearchCV属性说明

  • cv_results_: dict of numpy(masked) ndarray
  • best_estimator_:
  • best_score_: Mean cross-validated score of the best_estimator
  • best_params_:
  • best_index_: int,The index (of the ``cv_results_`` arrays) which corresponds to the best candidate parameter setting
  • scorer_:
  • n_splits_: The number of cross-validation splits (folds/iterations)
  • refit_time: float

参考资料:

  • (1) <机器学习实战基于scikit-learn和tensorflow>
  • (2) <百面机器学习>
  • (3)李航 <统计学习方法>

猜你喜欢

转载自www.cnblogs.com/xiaobingqianrui/p/11072556.html