sklearn库学习之决策树

决策树

学习决策树,就是学习一系列if/else问题,在机器学习中,这些问题叫做测试,算法搜索所有可能的测试,找出对目标变量来说信息量最大的哪一个。
算法过程生成一棵二元决策树,其中每个结点都包含一个测试。将每个测试看成沿着一条轴对当前数据进行划分,由于每个测试仅仅关注一个特征,所以划分后的区域边界始终与坐标轴平行。反复划分,直到决策树的每个叶结点变成纯的。

查看新数据点位于特征空间划分的哪个区域,即基于每个结点的测试对树进行遍历,找到新数据点所属的叶结点。
分类:将该区域的多数目标值作为预测结果。
回归:叶结点中所有训练点的平均目标值。

防止过拟合的两种策略:

  1. 预剪枝:及早停止树的增长,限制树的最大深度、限制叶结点的最大数目、规定结点中数据点的最小数目。
  2. 后剪枝:先构造树,然后删除信息量很小的结点。

优点:

  1. 得到的模型容易可视化
  2. 算法完全不受数据缩放的影响,决策树算法不需要特征预处理

缺点:

  1. 经常会过拟合,泛化性能很差
#在乳腺癌数据集上查看预剪枝的效果
import os
os.environ["PATH"] += os.pathsep + r'the Graphviz bin file address on your system'
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import mglearn
import matplotlib.pyplot as plt
import numpy as np

cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(cancer.data, cancer.target, stratify = cancer.target, random_state = 42)

#先使用默认设置构建模型
tree = DecisionTreeClassifier(random_state = 0)
tree.fit(X_train,y_train)

#未剪枝的树容易过拟合,对新数据的泛化性能不佳
print("Accuracy on training set: {:.3f}".format(tree.score(X_train,y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test,y_test)))

#预剪枝
tree = DecisionTreeClassifier(max_depth = 4, random_state = 0)
tree.fit(X_train,y_train)
print("Accuracy on training set: {:.3f}".format(tree.score(X_train,y_train)))
print("Accuracy on test set: {:.3f}".format(tree.score(X_test,y_test)))

#分析决策树,将树可视化
from sklearn.tree import export_graphviz #生成.dot格式文件
export_graphviz(tree, out_file = "tree.dot",class_names = ['malignant','benign'],feature_names = cancer.feature_names,
               impurity = False, filled = True) #为结点添加颜色,传入类别名称和特征名称

#利用graphviz模块读取.dot文件并可视化
#观察树,找出大部分数据的实际路径
import graphviz
with open("tree.dot") as f:
    dot_graph = f.read()
graphviz.Source(dot_graph)


#树的特征重要性
print("Feature importance:\n{}".format(tree.feature_importances_))

#将特征重要性可视化
def plot_feature_importances_cancer(model):
    n_features = cancer.data.shape[1]
    plt.barh(range(n_features),model.feature_importances_,align = 'center')
    plt.yticks(np.arange(n_features),cancer.feature_names)
    plt.xlabel("Feature importance")
    plt.ylabel("Feature")

#如果某特征的重要性较小,说明该特征没有被树选中,或另一个特征也包含了相同的信息
plot_feature_importances_cancer(tree)

DecisionTreeRegressor以及其他所有基于树的回归模型不能外推,也不能在训练数据范围之外进行预测。

对代码中的疑惑

Python sklearn库中决策树tree.DecisionTreeClassifier()函数参数介绍
https://blog.csdn.net/li980828298/article/details/51172744

Matplotlib快速绘图
https://www.cnblogs.com/peihao/p/5290075.html

猜你喜欢

转载自blog.csdn.net/thj19980720/article/details/83177351