CART 算法手写数字识别

from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.metrics import accuracy_score
from sklearn.datasets import load_digits
from sklearn.svm import SVC
import matplotlib as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns # 数据可视化的包

加载数据

digits = load_digits()
data = digits.data

查看数据集大小

data.shape

数据集介绍

1797个样本,每个样本包括88像素的图像和一个[0, 9]整数的标签。

array矩阵类型数据,保存8
8的图像,里面的元素是float64类型,共有1797张图片
用于显示图片。

获取第一张图片的像素数

print(digits.images[0])

将25%的数据作为测试集,其余作为训练集

train_x, test_x, train_y, test_y = train_test_split(data, digits.target, test_size=0.25, random_state=33)

采用Z-Score规范化

ss = preprocessing.StandardScaler()
train_ss_x = ss.fit_transform(train_x)
test_ss_x = ss.transform(test_x)

CART 算法简单介绍

Classification And Regression Tree,即分类回归树算法,简称CART算法,它是决策树的一种实现,通常决策树主要有三种实现,分别是ID3算法,CART算法和C4.5算法。

CART 算法采用 Gini系数作为标准进行特征分割。

决策树的算法原理大家不理解属于正常,老师还没有讲到。

有兴趣了解的同学可以看一下链接:

https://zhuanlan.zhihu.com/p/30059442

https://zhuanlan.zhihu.com/p/104462031

#训练一个DecisionTree分类器

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0,splitter='best',criterion='gini') # sklearn默认使用基尼Gini系数
clf.fit(train_ss_x,train_y)

predict_y = clf.predict(test_ss_x)
print('CART算法准确率: %0.4lf' % accuracy_score(test_y, predict_y))

猜你喜欢

转载自blog.csdn.net/weixin_44659309/article/details/107054586