绘制决策树的图片可以使用sklearn.tree.plot_tree
这个方法
详情可以参考官方文档:https://scikit-learn.org/stable/modules/generated/sklearn.tree.plot_tree.html
示例代码
import matplotlib.pyplot as plt
import numpy as np
from sklearn.tree import plot_tree # 树图
from sklearn.tree import DecisionTreeClassifier
if __name__ == '__main__':
# 准备数据
x_data = np.array([
[1, 2, 3, 4, 5, 6],
[2, 2, 3, 4, 5, 6],
]).T
y_data = np.array([6, 5, 4, 3, 4, 5])
# 训练一个树模型
dec_tree = DecisionTreeClassifier(
criterion='entropy', # “信息熵”最小化准则划分
max_leaf_nodes=8, # 最大叶子节点数
min_samples_leaf=0.05) # 叶子节点样本数量最小占比
dec_tree.fit(x_data, y_data) # 训练决策树
# 开始绘图
plt.figure(figsize=(14, 12)) # 指定图片大小
plot_tree(dec_tree,
feature_names=["x1", "x2"],
filled=True,
rounded=True)
plt.show()
效果图: