Python决策树鸢尾花类别分类

Python决策树鸢尾花类别分类

引入使用到的模块

#决策树   鸢尾花类别
import sklearn.datasets as dSets
import pandas as pd;
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import sklearn.tree as sklTree

获取数据

iris = dSets.load_iris();
#print(iris);
data = iris.data;
target = iris.target;
print("data",data);  #属性值
print("target",target);   #结果

输出
在这里插入图片描述
在这里插入图片描述
划分测试集和训练集和

#画分测试机和训练集   训练测试7:3
x_train,x_test,y_train,y_test = train_test_split(data,target,test_size=0.3,random_state=1);

决策树参数说明

1.criterion:string, optional (default="mse")
            它指定了切分质量的评价准则。默认为'mse'(mean squared error)。
2.splitter:string, optional (default="best")
            它指定了在每个节点切分的策略。有两种切分策咯:
            (1).splitter='best':表示选择最优的切分特征和切分点。
            (2).splitter='random':表示随机切分。
3.max_depth:int or None, optional (default=None)
             指定树的最大深度。如果为None,则表示树的深度不限,直到
             每个叶子都是纯净的,即叶节点中所有样本都属于同一个类别,
             或者叶子节点中包含小于min_samples_split个样本。
4.min_samples_split:int, float, optional (default=2)
             整数或者浮点数,默认为2。它指定了分裂一个内部节点(非叶子节点)
             需要的最小样本数。如果为浮点数(0到1之间),最少样本分割数为ceil(min_samples_split * n_samples)
5.min_samples_leaf:int, float, optional (default=1)
             整数或者浮点数,默认为1。它指定了每个叶子节点包含的最少样本数。
             如果为浮点数(0到1之间),每个叶子节点包含的最少样本数为ceil(min_samples_leaf * n_samples)
6.min_weight_fraction_leaf:float, optional (default=0.)
             它指定了叶子节点中样本的最小权重系数。默认情况下样本有相同的权重。
7.max_feature:int, float, string or None, optional (default=None)
             可以是整数,浮点数,字符串或者None。默认为None。
             (1).如果是整数,则每次节点分裂只考虑max_feature个特征。
             (2).如果是浮点数(0到1之间),则每次分裂节点的时候只考虑int(max_features * n_features)个特征。
             (3).如果是字符串'auto',max_features=n_features。
             (4).如果是字符串'sqrt',max_features=sqrt(n_features)。
             (5).如果是字符串'log2',max_features=log2(n_features)。
             (6).如果是None,max_feature=n_feature。
8.random_state:int, RandomState instance or None, optional (default=None)
             (1).如果为整数,则它指定了随机数生成器的种子。
             (2).如果为RandomState实例,则指定了随机数生成器。
             (3).如果为None,则使用默认的随机数生成器。
9.max_leaf_nodes:int or None, optional (default=None)
             (1).如果为None,则叶子节点数量不限。
             (2).如果不为None,则max_depth被忽略。
10.min_impurity_decrease:float, optional (default=0.)
             如果节点的分裂导致不纯度的减少(分裂后样本比分裂前更加纯净)大于或等于min_impurity_decrease,则分裂该节点。
             个人理解这个参数应该是针对分类问题时才有意义。这里的不纯度应该是指基尼指数。
             回归生成树采用的是平方误差最小化策略。分类生成树采用的是基尼指数最小化策略。
             加权不纯度的减少量计算公式为:
             min_impurity_decrease=N_t / N * (impurity - N_t_R / N_t * right_impurity
                                - N_t_L / N_t * left_impurity)
             其中N是样本的总数,N_t是当前节点的样本数,N_t_L是分裂后左子节点的样本数,
             N_t_R是分裂后右子节点的样本数。impurity指当前节点的基尼指数,right_impurity指
             分裂后右子节点的基尼指数。left_impurity指分裂后左子节点的基尼指数。
11.min_impurity_split:float
             树生长过程中早停止的阈值。如果当前节点的不纯度高于阈值,节点将分裂,否则它是叶子节点。
             这个参数已经被弃用。用min_impurity_decrease代替了min_impurity_split。
12.presort: bool, optional (default=False)
             指定是否需要提前排序数据从而加速寻找最优切分的过程。设置为True时,对于大数据集
             会减慢总体的训练过程;但是对于一个小数据集或者设定了最大深度的情况下,会加速训练过程。

使用GridSearchCV自动调参 比如我们调min_samples_leaf和max_depth

from sklearn.model_selection import GridSearchCV
tree_param_grid = {'min_samples_leaf':[1,2,3,4,5,6,7,8,9,10],'max_depth':[2,3,4,]};
grid = GridSearchCV(sklTree.DecisionTreeClassifier(),param_grid=tree_param_grid,cv=5);
grid.fit(x_train,y_train);
print('best_params_',grid.best_params_);
print('best_score_',grid.best_score_);

输出
在这里插入图片描述
所以我们选择 max_depth=3,min_samples_leaf=1,的组合

训练

#训练
decisionTree = sklTree.DecisionTreeClassifier(min_samples_leaf=1,max_depth = 3);
decisionTree.fit(x_train,y_train);  #训练
print(decisionTree);
print(decisionTree.score(x_test,y_test))

输出
在这里插入图片描述
预测 和画图

predict = decisionTree.predict(x_test);#预测

print(predict)
print(y_test)

#获取花卉的两列数据
sepallength = [i[0] for i in x_test];
print('sepallength',sepallength);
sepalwidth = [i[1] for i in x_test];
print('sepalwidth',sepalwidth);
plt.scatter(sepallength,sepalwidth,c=predict,marker='x');
plt.title('DTC');
plt.show();

输出
在这里插入图片描述
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_44235109/article/details/106188737