机器学习sklearn(6)决策树分类

from sklearn.tree import DecisionTreeClassifier
import pandas as pd

导入数据集

# Column labels for the lenses data set; the final column is used as the label.
names = ['age', 'prescript', 'astigmatic', 'tearrate', 'class']
data_path = r"C:\Users\Machine Learning\lenses.txt"

# The file is read as tab-separated text, with `names` supplying the headers.
lenses = pd.read_csv(
    data_path,
    names=names,
    sep='\t',
)

# Show the first rows as a quick sanity check of the parse.
preview = lenses.head()
print(preview)

     age prescript astigmatic tearrate      class
0  young     myope         no  reduced  no lenses
1  young     myope         no   normal       soft
2  young     myope        yes  reduced  no lenses
3  young     myope        yes   normal       hard
4  young     hyper         no  reduced  no lenses

转换成array格式,拆分特征量和标签

import numpy as np

# Convert the DataFrame to a plain array so columns can be sliced by position.
lenses = np.array(lenses)
# Every column except the last is a feature; the last column is the label.
lenses_data, lenses_target = lenses[:, :-1], lenses[:, -1]

DecisionTreeClassifier.fit()不能接收string类型的特征数据,需要先对特征进行编码(标签可以保留为字符串)

from sklearn import preprocessing

# Numeric matrix with the same shape as the string feature matrix.
data = np.zeros(shape=lenses_data.shape)
le = preprocessing.LabelEncoder()
# Encode one feature column at a time. The single encoder is refitted for
# each column, so after the loop it only remembers the last column's mapping.
for col in range(lenses_data.shape[1]):
    column = lenses_data[:, col]
    data[:, col] = le.fit(column).transform(column)

# Entropy-based tree limited to depth 4; the labels are passed as strings.
clf = DecisionTreeClassifier(criterion='entropy', max_depth=4)
clf.fit(data, lenses_target)

DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=4,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')

# sklearn.externals.six is no longer shipped with scikit-learn; the standard
# library's io.StringIO provides the same in-memory text buffer.
from io import StringIO
import pydotplus

可视化决策树

# Render the fitted tree as Graphviz DOT text in memory, then save it as a PDF.
from sklearn import tree

dot_data = StringIO()
tree.export_graphviz(
    clf,
    out_file=dot_data,
    feature_names=names[:-1],
    # Pass a plain list: clf.classes_ is a NumPy array, and some scikit-learn
    # releases validate class_names strictly and reject an ndarray.
    class_names=list(clf.classes_),
    filled=True,     # fill nodes with colour
    rounded=True,    # draw node boxes with rounded corners
    special_characters=True,
)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
graph.write_pdf("tree_.pdf")

True

内容有点多不太好理解,可以对照原始数据一起来看

# Put the encoded features and the original labels side by side for reading
# the tree. column_stack turns the 1-D label vector into a column first;
# np.hstack would fail here because `data` is 2-D and `lenses_target` is 1-D.
data_imp = np.column_stack([data, lenses_target])
data_pd = pd.DataFrame(
    data_imp,
    columns=names,
)
data_pd

samples表示当前节点样本个数,value表示当前节点不同标签(“hard”,“no lenses”,“soft”,按照英文字母顺序排列)的样本个数,class表示当前节点三个不同标签中样本数最多的标签

根据生成的决策树进行预测

# Classify one already-encoded sample; the values follow the feature order
# in names[:-1] (age, prescript, astigmatic, tearrate).
# NOTE(review): the integer codes come from the per-column LabelEncoder fit,
# presumably in sorted order of each column's strings — confirm before reuse.
sample = [[1, 0, 1, 0]]
prediction = clf.predict(sample)
print(prediction)

['no lenses']

猜你喜欢

转载自blog.csdn.net/weixin_44530236/article/details/88693359