Python机器学习【二】 - 决策树

Python机器学习【二】 - 决策树

原文地址:Python机器学习【二】 - 决策树

上一篇基于sklearn Python库创建K近邻模型(KNN)实现了机器学习Hello World示例,KNN属于最简单的分类算法,简单来讲就是当预测一个新的值时,根据它最近的K个点是什么类型来判断属于哪个类别。掌握了KNN也算是简单入门机器学习了,学习新东西总是要由浅入深,这次准备学习另一种分类算法:决策树

一、决策树

决策树算法采用树形结构,使用层层推理来实现最终的分类,是一种基于 if-then-else 规则的有监督学习算法,决策树的这些规则通过训练得到,而不是人工制定的。

有监督学习: 给定数据,预测标签,它从有标记的训练数据中推导出预测函数。

无监督学习: 给定数据,寻找隐藏的结构,它从无标记的训练数据中推断结论。

我们的生活中有很多决策树的案例,举个简单栗子:银行要对客户进行信用登记评估,可能会根据有无逾期记录、有无逾期未还款记录、有无房产等指标进行评估,得到下面最简单的模型:

二、目标

已知有178个葡萄酒样本数据,每个样本采样酒精、苹果酸、灰、灰的碱性、镁、总酚等共13个特征,葡萄酒分为三个类型,分别用class_0、class_1、class_2表示。

178个葡萄酒中:class_0有59个,class_1有71个,class_2有48个,我们仍然取178条数据的70%作为训练数据,经过训练模型、模型评估,预测剩余30%数据样本的类型,计算出准确率以及查看学习生成的决策树。

三、实现步骤
  1. 安装sklearn
pip install -U scikit-learn
  1. Python实现
  • 获取数据
from sklearn import datasets


# 获取葡萄酒据集,sklearn库提供
wines = datasets.load_wine()
"""
共178个样本,样本具有13个特征值(feature_names)
feature_names:
    alcohol: 酒精
    malic_acid: 苹果酸
    ash: 灰
    alcalinity_of_ash: 灰的碱性
    magnesium: 镁
    total_phenols: 总酚
    flavanoids: 类黄酮
    nonflavanoid_phenols: 非黄烷类酚类
    proanthocyanins: 花青素
    color_intensity: 颜色强度
    hue: 色调
    od280/od315_of_diluted_wines: od280/od315稀释葡萄酒
    proline: 脯氨酸
三种红酒类型(classes)
classes:
    class_0
    class_1
    class_2
"""
print("葡萄酒类型", list(wines.target_names))  # ['class_0', 'class_1', 'class_2']
print("样本特征值", list(wines.feature_names))  # 太长,见注释
print("数据规模", wines.data.shape)  # (178, 13)
  • 数据预处理
x_train, x_test, y_train, y_test = train_test_split(wines.data, wines.target, test_size=0.3)  # 分训练集、测试集  测试集占0.3
print("训练集", x_train.shape)  # (124, 13)
print("测试集", x_test.shape)  # (54, 13)
  • 训练模型
clf = tree.DecisionTreeClassifier(criterion="entropy")  # 载入决策树分类模型
clf = clf.fit(x_train, y_train)
  • 模型评估
score = clf.score(x_test, y_test)
print("准确度", score)  # 0.9629629629629629
  • 画出该决策树

graphviz是一个绘图工具,可以根据dot脚本画出树形图,windows平台使用graphviz需要先安装graphviz软件,再安装graphviz python模块

  1. 安装graphviz软件

下载地址:graphviz软件
注意安装时勾选 Add Graphviz to the system PATH,安装成功后,需要重启 PyCharm

  1. 安装graphviz模块
pip install graphviz
  1. 画出决策树
import graphviz

feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮', '非黄烷类酚类', '花青素', '颜色强度', '色调', 'od280/od315稀释葡萄酒', '脯氨酸']
class_names = ["class_0", "class_1", "class_2"]
dot_data = tree.export_graphviz(clf, out_file='tree.dot', feature_names=feature_name, class_names=class_names, filled=True, rounded=True)
# 处理中文乱码
with open("tree.dot", encoding='utf-8') as f:
    dot_graph = f.read()
    graph = graphviz.Source(dot_graph.replace("helvetica", "FangSong"))
graph.view()
完整代码
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import tree
import graphviz

# 获取葡萄酒据集,sklearn库提供
wines = datasets.load_wine()
"""
共178个样本,样本具有13个特征值(feature_names)
feature_names:
    alcohol: 酒精
    malic_acid: 苹果酸
    ash: 灰
    alcalinity_of_ash: 灰的碱性
    magnesium: 镁
    total_phenols: 总酚
    flavanoids: 类黄酮
    nonflavanoid_phenols: 非黄烷类酚类
    proanthocyanins: 花青素
    color_intensity: 颜色强度
    hue: 色调
    od280/od315_of_diluted_wines: od280/od315稀释葡萄酒
    proline: 脯氨酸
三种红酒类型(classes)
classes:
    class_0
    class_1
    class_2
"""
print("葡萄酒类型", list(wines.target_names))  # ['class_0', 'class_1', 'class_2']
print("样本特征值", list(wines.feature_names))  # 太长,见注释
print("数据规模", wines.data.shape)  # (178, 13)

x_train, x_test, y_train, y_test = train_test_split(wines.data, wines.target, test_size=0.3)  # 分训练集、测试集  测试集占0.3
print("训练集", x_train.shape)
print("测试集", x_test.shape)

clf = tree.DecisionTreeClassifier(criterion="entropy")  # 载入决策树分类模型
clf = clf.fit(x_train, y_train)
score = clf.score(x_test, y_test)
print("准确度", score)

feature_name = ['酒精', '苹果酸', '灰', '灰的碱性', '镁', '总酚', '类黄酮', '非黄烷类酚类', '花青素', '颜色强度', '色调', 'od280/od315稀释葡萄酒', '脯氨酸']
class_names = ["class_0", "class_1", "class_2"]
dot_data = tree.export_graphviz(clf, out_file='tree.dot', feature_names=feature_name, class_names=class_names, filled=True, rounded=True)
# 处理中文乱码
with open("tree.dot", encoding='utf-8') as f:
    dot_graph = f.read()
    graph = graphviz.Source(dot_graph.replace("helvetica", "FangSong"))
graph.view()
参考

猜你喜欢

转载自blog.csdn.net/zszangy/article/details/113811594