Decision tree generation, decision tree visualization, decision tree algorithm API, and Titanic passenger survival prediction case code

1. Decision tree algorithm API

  • class sklearn.tree.DecisionTreeClassifier(criterion='gini', max_depth=None, random_state=None)
    • criterion : feature selection criterion, "gini" or "entropy". The former uses the Gini impurity, the latter uses information gain (entropy). The default is "gini", which corresponds to the CART algorithm
    • min_samples_split : the minimum number of samples required to split an internal node. This value limits whether a subtree may continue to be split: if a node contains fewer than min_samples_split samples, no further attempt is made to select an optimal feature for splitting. The default is 2. If the sample size is small, you do not need to worry about this value; if the sample size is very large, it is recommended to increase it. For a decision tree built on 100,000 samples, min_samples_split=10 is a reasonable reference value
    • min_samples_leaf : the minimum number of samples required at a leaf node. If a leaf ends up with fewer samples than this value, it is pruned together with its sibling nodes. The default is 1; you can pass an integer (minimum sample count) or a fraction (minimum proportion of the total sample count). If the sample size is small, you do not need to worry about this value; if the sample size is very large, it is recommended to increase it. For 100,000 samples, min_samples_leaf=5 is a reasonable reference value
    • max_depth : the maximum depth of the decision tree. It can be left unset (the default), in which case the depth of the subtrees is not limited while the tree is built. Generally this value can be ignored when there is little data or few features; if the model has many samples and many features, it is recommended to limit the maximum depth. The specific value depends on the distribution of the data; commonly used values fall between 10 and 100
    • random_state: Random number seed
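
A minimal, self-contained sketch of how these parameters might be set. The iris dataset and the specific parameter values below are illustrative choices, not part of the original post:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Illustrative data; any (X, y) classification dataset works the same way
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=22, test_size=0.2)

# The hyperparameters described above; the values chosen here are arbitrary examples
clf = DecisionTreeClassifier(
    criterion="gini",        # or "entropy" for information gain
    max_depth=5,             # limit tree depth to reduce overfitting
    min_samples_split=4,     # a node needs at least 4 samples to be split further
    min_samples_leaf=2,      # each leaf must keep at least 2 samples
    random_state=22,
)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))  # accuracy on the held-out split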

2. Case: Titanic Passenger Survival Prediction

Titanic data: the Titanic and titanic2 data frames describe the survival status of individual passengers on the Titanic. The dataset used here originated from the work of various researchers and includes passenger lists compiled by many of them; it was edited by Michael A. Findlay. The features extracted from the dataset are ticket class, survival, name, sex, age, etc.

The content and processing workflow of the Titanic training data train.csv are as follows

 The complete code is as follows

import pandas as pd
import numpy as np
from sklearn.feature_extraction import DictVectorizer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

data = pd.read_csv('../data/train.csv')
data
------------------------------------------
data.describe()
------------------------
# Basic data processing: determine the feature values and the target value
x = data[["Pclass", "Age", "Sex"]]
x
------------------------
y = data["Survived"]
y.head()
------------------------
# Missing values need to be handled; the categorical features will be converted via dictionary feature extraction later
x = x.copy()  # work on a copy to avoid pandas' SettingWithCopyWarning
x['Age'] = x['Age'].fillna(x['Age'].mean())
x
-------------------------------------------
# Split the dataset into training and test sets
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=22, test_size=0.2)
x.head()
-------------------------------------------
# Feature engineering (dictionary feature extraction)
# x.to_dict(orient="records") converts the DataFrame rows into a list of dicts
x_train = x_train.to_dict(orient="records")
x_test = x_test.to_dict(orient="records")
x_train
-------------------------------------------
# Categorical features need one-hot encoding (DictVectorizer)
transfer = DictVectorizer(sparse=False)  # instantiate a transformer; sparse=False returns a dense array instead of a sparse matrix
x_train = transfer.fit_transform(x_train)   # fit the vectorizer on the training data and transform it
x_test = transfer.transform(x_test)  # use transform (not fit_transform) so the test set shares the same feature mapping
x_train
-------------------------------------------
# Machine learning (decision tree): train the decision tree model
# In the decision tree API, if max_depth is not specified, the tree keeps splitting according to the entropy criterion until it terminates on its own; here we limit the size of the tree by specifying its depth
estimator = DecisionTreeClassifier(criterion="entropy", max_depth=5)
estimator.fit(x_train, y_train)
-------------------------------------------
# Model evaluation
estimator.predict(x_test)  # predicted values
-------------------------
estimator.score(x_test, y_test)   # accuracy
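
Beyond raw accuracy, you could also look at per-class precision and recall. This is an optional extra check, not part of the original post (label 0 = did not survive, label 1 = survived):

from sklearn.metrics import classification_report

# Per-class precision/recall/F1 on the held-out test set
y_pred = estimator.predict(x_test)
print(classification_report(y_test, y_pred, target_names=["died", "survived"]))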

3. Decision tree visualization

Save the structure of the tree to a dot file

  • sklearn.tree.export_graphviz(): This function can export DOT format
    • tree.export_graphviz(estimator, out_file='tree.dot', feature_names=['', ''])
  • Pros: easy to understand and explain, tree visualization
  • Cons: decision tree learners can create overly complex trees that do not generalize the data well, i.e. they are prone to overfitting
  • Improve:
    • pruning cart algorithm
    • Random Forest (a type of ensemble learning; a sketch follows below)
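
As an illustration of the Random Forest improvement (not part of the original post), here is a minimal sketch that reuses the one-hot encoded x_train/x_test arrays prepared in the case above; the hyperparameter values are arbitrary examples:

from sklearn.ensemble import RandomForestClassifier

# An ensemble of decision trees: each tree is trained on a bootstrap sample with random feature subsets
rf = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=22)
rf.fit(x_train, y_train)          # reuses the encoded training data from the case above
print(rf.score(x_test, y_test))   # accuracy, comparable to the single decision tree's score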

Because of their good analytical and interpretive ability, decision trees are widely used in important enterprise decision-making processes, and they can also be used for feature selection

Continuing from the code above, execute the following

export_graphviz(estimator, out_file="../data/tree.dot", feature_names=['Age', 'Pclass', 'male', 'female'])
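
As an optional sanity check (not in the original post), you can print the column order actually produced by DictVectorizer, since export_graphviz assigns feature_names positionally; DictVectorizer typically orders columns by sorted feature name. get_feature_names_out requires scikit-learn 1.0+, older versions use get_feature_names:

# e.g. ['Age', 'Pclass', 'Sex=female', 'Sex=male']; make sure the feature_names list above matches this order
print(transfer.get_feature_names_out())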

Running this generates the tree.dot file in the data directory

 The content of the tree.dot file is as follows

digraph Tree {
node [shape=box, fontname="helvetica"] ;
edge [fontname="helvetica"] ;
0 [label="female <= 0.5\nentropy = 0.96\nsamples = 712\nvalue = [439, 273]"] ;
1 [label="Pclass <= 2.5\nentropy = 0.802\nsamples = 250\nvalue = [61, 189]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="Age <= 27.5\nentropy = 0.264\nsamples = 134\nvalue = [6, 128]"] ;
1 -> 2 ;
3 [label="Age <= 23.5\nentropy = 0.496\nsamples = 46\nvalue = [5, 41]"] ;
2 -> 3 ;
4 [label="Age <= 2.5\nentropy = 0.206\nsamples = 31\nvalue = [1, 30]"] ;
3 -> 4 ;
5 [label="entropy = 1.0\nsamples = 2\nvalue = [1, 1]"] ;
4 -> 5 ;
6 [label="entropy = 0.0\nsamples = 29\nvalue = [0, 29]"] ;
4 -> 6 ;
7 [label="Age <= 24.5\nentropy = 0.837\nsamples = 15\nvalue = [4, 11]"] ;
3 -> 7 ;
8 [label="entropy = 0.592\nsamples = 7\nvalue = [1, 6]"] ;
7 -> 8 ;
9 [label="entropy = 0.954\nsamples = 8\nvalue = [3, 5]"] ;
7 -> 9 ;
10 [label="Age <= 56.5\nentropy = 0.09\nsamples = 88\nvalue = [1, 87]"] ;
2 -> 10 ;
11 [label="entropy = 0.0\nsamples = 82\nvalue = [0, 82]"] ;
10 -> 11 ;
12 [label="Pclass <= 1.5\nentropy = 0.65\nsamples = 6\nvalue = [1, 5]"] ;
10 -> 12 ;
13 [label="entropy = 0.0\nsamples = 5\nvalue = [0, 5]"] ;
12 -> 13 ;
14 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0]"] ;
12 -> 14 ;
15 [label="Age <= 38.5\nentropy = 0.998\nsamples = 116\nvalue = [55, 61]"] ;
1 -> 15 ;
16 [label="Age <= 1.5\nentropy = 0.988\nsamples = 108\nvalue = [47, 61]"] ;
15 -> 16 ;
17 [label="entropy = 0.0\nsamples = 4\nvalue = [0, 4]"] ;
16 -> 17 ;
18 [label="Age <= 32.5\nentropy = 0.993\nsamples = 104\nvalue = [47, 57]"] ;
16 -> 18 ;
19 [label="entropy = 0.997\nsamples = 100\nvalue = [47, 53]"] ;
18 -> 19 ;
20 [label="entropy = 0.0\nsamples = 4\nvalue = [0, 4]"] ;
18 -> 20 ;
21 [label="entropy = 0.0\nsamples = 8\nvalue = [8, 0]"] ;
15 -> 21 ;
22 [label="Age <= 13.0\nentropy = 0.684\nsamples = 462\nvalue = [378, 84]"] ;
0 -> 22 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
23 [label="Pclass <= 2.5\nentropy = 0.948\nsamples = 30\nvalue = [11, 19]"] ;
22 -> 23 ;
24 [label="entropy = 0.0\nsamples = 11\nvalue = [0, 11]"] ;
23 -> 24 ;
25 [label="Age <= 0.71\nentropy = 0.982\nsamples = 19\nvalue = [11, 8]"] ;
23 -> 25 ;
26 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
25 -> 26 ;
27 [label="Age <= 11.5\nentropy = 0.964\nsamples = 18\nvalue = [11, 7]"] ;
25 -> 27 ;
28 [label="entropy = 0.937\nsamples = 17\nvalue = [11, 6]"] ;
27 -> 28 ;
29 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
27 -> 29 ;
30 [label="Pclass <= 1.5\nentropy = 0.611\nsamples = 432\nvalue = [367, 65]"] ;
22 -> 30 ;
31 [label="Age <= 60.5\nentropy = 0.888\nsamples = 95\nvalue = [66, 29]"] ;
30 -> 31 ;
32 [label="Age <= 47.5\nentropy = 0.922\nsamples = 83\nvalue = [55, 28]"] ;
31 -> 32 ;
33 [label="entropy = 0.874\nsamples = 68\nvalue = [48, 20]"] ;
32 -> 33 ;
34 [label="entropy = 0.997\nsamples = 15\nvalue = [7, 8]"] ;
32 -> 34 ;
35 [label="Age <= 75.5\nentropy = 0.414\nsamples = 12\nvalue = [11, 1]"] ;
31 -> 35 ;
36 [label="entropy = 0.0\nsamples = 11\nvalue = [11, 0]"] ;
35 -> 36 ;
37 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
35 -> 37 ;
38 [label="Age <= 32.25\nentropy = 0.49\nsamples = 337\nvalue = [301, 36]"] ;
30 -> 38 ;
39 [label="Age <= 30.75\nentropy = 0.535\nsamples = 254\nvalue = [223, 31]"] ;
38 -> 39 ;
40 [label="entropy = 0.483\nsamples = 239\nvalue = [214, 25]"] ;
39 -> 40 ;
41 [label="entropy = 0.971\nsamples = 15\nvalue = [9, 6]"] ;
39 -> 41 ;
42 [label="Age <= 41.5\nentropy = 0.328\nsamples = 83\nvalue = [78, 5]"] ;
38 -> 42 ;
43 [label="entropy = 0.156\nsamples = 44\nvalue = [43, 1]"] ;
42 -> 43 ;
44 [label="entropy = 0.477\nsamples = 39\nvalue = [35, 4]"] ;
42 -> 44 ;
}

You can copy the content of the tree.dot file into the Webgraphviz website and run it there to visualize the decision tree. When I tried it, however, the website seemed to be down and could not be loaded

Instead, use the following method to render it; see: graphviz installation and use, decision tree generation
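
A minimal sketch of rendering the dot file locally, assuming the Graphviz binaries and the Python graphviz package are installed; the file paths follow the case above:

import graphviz

# Read the exported DOT file and render it to a PNG image next to it
with open("../data/tree.dot") as f:
    dot_source = f.read()
graphviz.Source(dot_source).render("../data/tree", format="png", cleanup=True)

Alternatively, the command line tool works directly on the file: dot -Tpng ../data/tree.dot -o ../data/tree.png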

The generated decision tree is as follows

Learning to navigate: http://xqnav.top/


Origin blog.csdn.net/qq_43874317/article/details/128586322