机器学习笔记 二十:在Iris数据集上实现决策树的可视化

写在前面

决策树是一种用于机器学习的监督算法。它使用一个二进制树形图(每个节点有两个孩子)为每个数据样本分配一个目标值,目标值呈现在树叶中。为了到达树叶,样本通过节点传播,从根节点开始。在每个节点中,决定它应该去哪个子孙节点。决定是根据所选样本的特征做出的。决策树学习是一个根据所选指标在每个内部树节点中寻找最佳规则的过程。这些都是老生常谈的问题了,希望大家简单了解一下即可。

Iris数据集

在这里插入图片描述
数据展示:

import sklearn.datasets as datasets
from sklearn.tree import DecisionTreeClassifier 
from sklearn import tree
from matplotlib import pyplot as plt
import pandas as pd

iris = datasets.load_iris()

X_df = pd.DataFrame(iris.data, columns = iris.feature_names)
print(X_df.head(15))

Y=iris.target
print("\nClass Labels for all the data points:\n", Y)

在这里插入图片描述

# 数据拟合
dtree = DecisionTreeClassifier()   # (random_state=1234)
model=dtree.fit(X_df,Y)

text_representation = tree.export_text(dtree)
print(text_representation)

在这里插入图片描述

# 结果保存
with open("iris_DecisionTree_text.txt", "w") as fout:
    fout.write(text_representation)

使用plot_tree绘制决策树

plot_tree方法在0.21版本中被添加到sklearn中,需要安装matplotlib。它允许我们很容易地产生树的图(不需要中间导出到graphviz)。

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(dtree, 
                   feature_names=iris.feature_names,  
                   class_names=iris.target_names,
                   filled=True)

在这里插入图片描述plot_tree中使用fill=True:当这个参数被设置为True时,该方法使用颜色来表示大部分的类。(如果能有一些与类和颜色相匹配的图例就更好了)。

import six
import sys
sys.modules['sklearn.externals.six'] = six

from sklearn.externals.six import StringIO
from IPython.display import Image
from sklearn.tree import export_graphviz
import pydotplus
print("Import Successful")

# 绘制决策树
dot_data = StringIO()
export_graphviz(dtree, out_file=dot_data, 
                feature_names = iris.feature_names, 
                filled = True, rounded = True, 
                special_characters = True, node_ids = True)
graph=pydotplus.graph_from_dot_data(dot_data.getvalue())

Image(graph.create_png())

# 保存数据
graph.write_png("iris_DecisionTree_graphivz1.png")

在这里插入图片描述

graphviz绘制决策树

import graphviz

# DOT format data
dot_data = tree.export_graphviz(dtree, out_file=None, 
                                feature_names=iris.feature_names,  
                                class_names=iris.target_names,
                                filled=True)

# Draw Decision Tree
graph = graphviz.Source(dot_data, format="png")  # change "png" to "pdf" for PDF format
graph

# 保存数据
graph.render("iris_DecisionTree_graphivz2")

在这里插入图片描述

dtreeviz绘制决策树

from dtreeviz.trees import dtreeviz # remember to load the package

viz = dtreeviz(dtree, X_df, Y,
                target_name="target",
                feature_names=iris.feature_names,
                class_names=list(iris.target_names))

viz

viz.save("iris_DecisionTree_dtreeviz.svg")

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/amyniez/article/details/128398940