[Study Notes] [Machine Learning] 6 (Part 2): Decision Tree Algorithm (Entropy, Information Gain (Ratio), Gini Value (Index), CART Pruning, Feature Engineering: Feature Extraction, Regression Decision Tree)

5. Decision tree algorithm API

Learning objectives:

  • Know how to use the decision tree algorithm API

sklearn.tree.DecisionTreeClassifier(criterion='gini', splitter='best', 
                                    max_depth=None, min_samples_split=2, 
                                    min_samples_leaf=1, 
                                    min_weight_fraction_leaf=0.0, 
                                    max_features=None, random_state=None, 
                                    max_leaf_nodes=None, 
                                    min_impurity_decrease=0.0, 
                                    class_weight=None, ccp_alpha=0.0)
  • Role : sklearn.tree.DecisionTreeClassifier is a decision tree classifier that can be used for classification tasks. A decision tree is a non-parametric supervised learning method that predicts the value of a target variable by learning simple decision rules from the data features. A decision tree can be viewed as a piecewise constant approximation.

  • Parameters :

    • criterion: Function used to measure split quality. Supported standards are:
      • "gini" (Gini impurity) is the default parameter, the CART algorithm
      • "entropy" (Shannon information gain)
    • splitter: The strategy used to select splits for each node. Supported strategies are "best" (choose the best split) and "random" (choose the best random split).
    • max_depth: The maximum depth of the decision tree. If None, nodes are expanded until all leaves are pure or until all leaves contain fewer than min_samples_split samples (i.e., if not set, the depth of the tree is not limited while it is being built).
      • Generally speaking, this value can be ignored when there is little data or few features. If the model has many samples and many features, it is recommended to limit the maximum depth; the specific value depends on the distribution of the data. Commonly used values are between 10 and 100.
    • min_samples_split: The minimum number of samples required to split an internal node.
      • This value limits the conditions under which a subtree can continue to be split. If the number of samples at a node is less than min_samples_split, the node will not be split further. The default is 2.
      • If the sample size is not large, there is no need to worry about this value. If the sample size is of a very large order of magnitude, it is recommended to increase it; for example, min_samples_split=10 was used when building a decision tree on 100,000 samples, for reference only.
    • min_samples_leaf: The minimum number of samples required for a leaf node.
      • This value limits the minimum number of samples in a leaf node. If a leaf node would contain fewer samples than this, it is pruned together with its sibling. The default is 1; you can pass an integer (minimum number of samples) or a fraction (minimum proportion of the total number of samples). If the sample size is not large, there is no need to worry about this value.
      • If the sample size is of a very large order of magnitude, it is recommended to increase it; for example, min_samples_leaf=5 for 100,000 samples, for reference only.
    • min_weight_fraction_leaf: The ratio of the minimum weighted sample number required by the leaf node to the total weight sum.
    • max_features: The number of features to consider when finding the best split.
    • random_state: Random number generator seed.
    • max_leaf_nodes: Maximum number of leaf nodes.
    • min_impurity_decrease: A node will be split if splitting would result in a reduction in impurity greater than or equal to this value.
    • class_weight: category weight.
  • Return value : Returns a decision tree classifier object that can be used to fit data, predict data, and perform other operations.
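As a quick illustration of the typical workflow (instantiate → fit → predict/score), here is a minimal sketch; the iris dataset bundled with scikit-learn is used purely as example data:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Example data: the iris dataset that ships with scikit-learn
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=22)

# criterion and max_depth are the most commonly tuned parameters
clf = DecisionTreeClassifier(criterion="gini", max_depth=3, random_state=22)
clf.fit(X_train, y_train)          # train on the training split

print(clf.predict(X_test[:5]))     # predicted class labels for a few samples
print(clf.score(X_test, y_test))   # mean accuracy on the test split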

For full details, see the official scikit-learn documentation page.

6. Case: Titanic Passenger Survival Prediction

Learning objectives:

  • Use this case to further master the decision tree algorithm API

6.1 Case background

The sinking of the Titanic is one of the most notorious shipwrecks in history. On April 15, 1912, during her maiden voyage, the Titanic sank after colliding with an iceberg, killing 1,502 of the 2,224 passengers and crew on board. The sensational tragedy shocked the international community and led to better safety regulations for ships. One reason for the high death toll was that there were not enough lifeboats for the passengers and crew. Although there was some element of luck involved in surviving the sinking, some groups were more likely to survive than others, such as women, children, and the upper class.

The figure of "1,502 deaths out of 2,224 passengers and crew" mentioned above is not exact. According to Wikipedia, there were 2,224 people on board the Titanic, including passengers and crew, and estimates of the death toll range from 1,490 to 1,635.

In this case, we are asked to analyze which kinds of people were likely to survive, and to apply machine learning tools to predict which passengers survived the tragedy.

Case: https://www.kaggle.com/c/titanic/overview

The extracted dataset includes features such as cabin class, survival status, age, home address/destination, cabin number, lifeboat number, gender, and so on.

(Figure: preview of the Titanic dataset)

Data (currently inaccessible): http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt
Data (accessible, but slightly discrepant): https://github.com/YBIFoundation/Dataset/blob/main/Titanic.txt

Property description :

  • pclass: cabin class ( 1, 2, 3)
  • survived: survived ( 0, 1)
  • name: name
  • sex: gender
  • age: age
  • sibsp: number of siblings/spouses on board
  • parch: number of parents/children on board
  • ticket: ticket number
  • fare: ticket price
  • cabin: cabin number
  • embarked: port of embarkation
  • boat: lifeboat number
  • body: body identification number
  • home.dest: home address/destination

Observations from the data:

  • pclass: cabin class ( 1, 2, 3) is a proxy for socioeconomic class
  • Some age values are missing

6.2 Step Analysis

  1. Get the data
  2. Basic data processing
    1. Determine feature values and target values
    2. Handle missing values
    3. Split the dataset
  3. Feature engineering (dictionary feature extraction)
  4. Machine learning (decision tree)
  5. Model evaluation

6.3 Code implementation

0. Import modules

import pandas as pd
import numpy as np
from sklearn.feature_extraction import DictVectorizer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

1. Get data

# 1. 获取数据
titanic = pd.read_csv("../data/titanic.txt")
titanic
titanic.describe()

(Figures: output of titanic and titanic.describe())

2. Basic data processing

2.1 Determine feature values and target values

We first analyze each attribute to determine which data is our feature value and which data is our target value:

| Feature | Description | Role |
| --- | --- | --- |
| ⭐️ pclass | cabin class (1, 2, 3) | Might be useful, use as a feature |
| ⭐️ survived | survived (0, 1) | Clearly the target value |
| name | name | Not relevant |
| ⭐️ sex | gender | Gender is mentioned in the background, use as a feature |
| ⭐️ age | age | Age is mentioned in the background, use as a feature |
| sibsp | number of siblings/spouses on board | Not relevant |
| parch | number of parents/children on board | Not relevant |
| ticket | ticket number | Not relevant |
| fare | ticket price | Not relevant |
| cabin | cabin number | Not relevant |
| embarked | port of embarkation | Not relevant |
| boat | lifeboat number | Not relevant |
| body | body identification number | Not relevant |
| home.dest | home address/destination | Not relevant (no passengers disembarked before the sinking) |

## 2.1 确定特征值和目标值
x = titanic[["pclass", "sex", "age"]]
y = titanic["survived"]

(Figure: preview of x and y)

2.2 Missing value processing

## 2.2 缺失值处理
# 缺失值需要处理,将特征中有类别的特征进行字典特征抽取
# 因为缺失值为 N/A,所以我们可以直接使用 .isnull() 方法来判断是否存在缺失值
x.loc[x['age'].isnull(), 'age'] = x['age'].mean()

This line of code uses .loc to modify data in the DataFrame. .loc is an indexer that lets us access data in the DataFrame by label.

In this example, the first argument to .loc is x['age'].isnull(), which returns a Boolean Series indicating whether the 'age' value of each row is missing. The second argument is 'age', indicating that we want to modify the values of the 'age' column.

So what this line of code does is replace the missing values in the 'age' column with the mean of the 'age' column.

Note : df.dropna cannot be used here, because many values are N/A; dropping them would also delete the corresponding rows, causing us to lose a large amount of data. So we replace the missing values instead of deleting rows.

For details, please refer to the handling of missing values in https://blog.csdn.net/weixin_44878336/article/details/130027046 .
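As a side note, the same replacement can also be written with fillna; a small sketch, assuming x has been defined as above:

# Alternative: fill missing ages with the column mean via fillna
x['age'] = x['age'].fillna(x['age'].mean())

# Quick check that no missing values remain in the selected features
print(x.isnull().sum())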

2.3 Dataset division

## 2.3 数据集划分
x_train, x_test, y_train, y_test = train_test_split(x, y, random_state=22)

3. Feature engineering (dictionary feature extraction)

The features contain categorical values, so one-hot encoding is required (implemented with the DictVectorizer class). However, DictVectorizer expects its input to be dictionaries, while our current data is a DataFrame, so we need to convert the DataFrame into dictionaries using the df.to_dict() method.

`x.to_dict(orient="records")`  # Convert the feature DataFrame into dictionary data.

df.to_dict(orient="records") is a method that converts a DataFrame to dictionaries. It takes an orient parameter to specify how the conversion should be done.

With orient="records", each row of data is converted into a dictionary whose keys are the column names and whose values are that row's values in the corresponding columns. Finally, all row dictionaries are collected into a list, which is returned as the result.

For example, suppose we have a DataFrame as follows:

   A  B
0  1  2
1  3  4

When we call df.to_dict(orient="records"), we get the following result:

[{'A': 1, 'B': 2}, {'A': 3, 'B': 4}]
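To see how DictVectorizer then one-hot encodes such record dictionaries, here is a small self-contained sketch (the two toy records are made up for illustration):

from sklearn.feature_extraction import DictVectorizer

# Toy records with numeric and categorical fields (illustrative only)
records = [{'pclass': 1, 'sex': 'female', 'age': 39.0},
           {'pclass': 3, 'sex': 'male', 'age': 11.0}]

vec = DictVectorizer(sparse=False)    # return a dense array instead of a sparse matrix
matrix = vec.fit_transform(records)   # string values are one-hot encoded; numbers pass through

print(vec.get_feature_names_out())    # e.g. ['age' 'pclass' 'sex=female' 'sex=male']
print(matrix)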

# 3. 特征工程(字典特征抽取)

## 3.1 实例化一个字典转换器类
transfer = DictVectorizer(sparse=False)  # 不用输出稀疏矩阵

## 3.2 将DataFrame转换为字典数据
x_train = x_train.to_dict(orient="records")
x_test = x_test.to_dict(orient="records")

print("转换为字典后的x_train为:", x_train)
print("\r\n转换为字典后的x_test为:", x_test)

## 3.3 特征转换
x_train = transfer.fit_transform(x_train)
# 注意:在测试数据上,应该使用与训练数据相同的转换方式,因此应该使用 `transform` 方法,
# 而不是 `fit_transform` 方法。`transform` 方法只进行转换,不会改变转换器的拟合结果。
x_test = transfer.transform(x_test)

print("\r\nx_train:\r\n", x_train)
print("\r\nx_test:\r\n", x_test)

Print result:

转换为字典后的x_train为: [{'pclass': 1, 'sex': 'female', 'age': 39.0}, {'pclass': 2, 'sex': 'female', 'age': 19.0}, ...]

转换为字典后的x_test为: [{'pclass': 3, 'sex': 'male', 'age': 11.0}, {'pclass': 3, 'sex': 'male', 'age': 29.881137667304014}, {'pclass': 3, 'sex': 'male', ...]

x_train:
 [[39.          1.          1.          0.        ]
 [19.          2.          1.          0.        ]
 [27.          3.          0.          1.        ]
 ...
 [29.88113767  3.          0.          1.        ]
 [24.          1.          0.          1.        ]
 [17.          3.          0.          1.        ]]

x_test:
 [[11.          3.          0.          1.        ]
 [29.88113767  3.          0.          1.        ]
 [ 4.          3.          0.          1.        ]
 ...
 [27.          2.          1.          0.        ]
 [49.          1.          0.          1.        ]
 [16.          1.          1.          0.        ]]

4. Machine Learning (Decision Trees)

In the decision tree API, if max_depth is not specified, the tree keeps splitting according to the chosen criterion until it reaches its natural end (for example, all leaves are pure). Here we can specify the depth of the tree to limit its size.

# 4. 机器学习
## 4.1 定义模型
estimator = DecisionTreeClassifier(criterion="entropy", max_depth=5)

# 4.2 模型训练
estimator.fit(x_train, y_train)
print("模型训练完成!")

5. Model Evaluation

# 5. 模型评估
score = estimator.score(x_test, y_test)
print(f"模型准确率为:{
      
      score * 100:.2f}%")

res = estimator.predict(x_test)
print("\r\n测试集预测结果为:\r\n", res)

result:

模型准确率为:75.61%

测试集预测结果为:
 [0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1 0 0 1 1 0 1 1 0 0 0 1 1 1 1
 0 0 0 1 0 0 0 0 1 1 0 1 0 1 1 0 0 0 1 1 1 0 0 1 0 1 0 0 0 0 0 0 1 0 1 1 0
 0 0 0 1 0 1 1 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 1 0 0 0 1 1 1 0 1 0 1 0 0
 0 1 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 1 0 0
 1 0 0 0 1 1 0 0 1 0 0 1 0 0 1 1 1 0 1 0 0 0 1 1 0 0 1 0 1 0 0 0 0 0 1 0 0
 0 0 0 0 0 1 0 1 0 0 0 0 0 1 0 1 1 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0 1 0
 1 0 0 0 0 0 1 0 0 0 1 0 1 0 0 0 1 0 1 1 1 1 1 0 1 1 0 0 0 0 1 0 0 0 0 0 0
 0 0 0 0 1 0 1 1 1 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 1 0 1 0 0 0 0 1 1 0 0 1
 0 0 1 1 1 0 0 0 1 0 0 1 0 1 0 1 0 1 1 1 1 0 1 1 1 0 0 0 1 1 0 1]

[Parameter adjustment] Change criterion

# 4. 机器学习
## 4.1 定义模型
estimator = DecisionTreeClassifier(criterion="gini", max_depth=5)

# 4.2 模型训练
estimator.fit(x_train, y_train)
print("模型训练完成!\r\n")

# 5. 模型评估
score = estimator.score(x_test, y_test)
print(f"模型准确率为:{
      
      score * 100:.2f}%")

res = estimator.predict(x_test)
print("\r\n测试集预测结果为:\r\n", res)

result:

模型训练完成!

模型准确率为:75.30%

测试集预测结果为:
 [0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 0 1 0 1 0 0 1 1 0 0 1 1 0 1 1 0 0 0 1 1 0 1
 0 0 0 1 0 0 0 0 1 1 0 1 0 1 1 0 0 0 1 1 1 0 0 1 0 1 0 0 0 0 0 0 1 0 1 1 0
 0 0 0 1 0 1 1 0 0 0 0 0 1 1 0 0 0 0 1 1 0 0 1 0 1 0 0 0 1 1 1 0 0 0 1 0 0
 0 0 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 0 0 0 1 0 1 0 0
 1 0 0 0 1 1 0 0 1 0 0 1 0 0 1 1 1 0 1 0 0 0 1 1 0 0 1 0 1 0 0 0 0 0 1 0 0
 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 0 1 0 0 0 0 1 1 0 0 1 1 0 0 0 0 0 0 1 0
 1 0 0 0 0 0 1 0 0 0 1 0 1 0 0 0 1 0 1 1 0 1 1 0 1 0 0 0 0 0 1 0 0 0 0 0 0
 0 0 1 0 1 0 1 0 1 1 0 0 0 1 1 0 0 0 1 0 0 0 0 0 1 1 0 1 0 0 0 0 0 1 0 0 1
 0 0 1 1 1 0 0 0 1 0 0 1 0 0 0 1 0 1 1 1 1 1 1 1 1 0 0 0 1 1 0 1]

[Parameter adjustment] Change the depth of the decision tree

import matplotlib.pyplot as plt
from pylab import mpl
# 设置中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
# 设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False


# 4. 机器学习

tree_depth = list(range(1, 101))
acc_test_lst = []
acc_train_lst = []

for depth in tree_depth:

    ## 4.1 定义模型
    estimator = DecisionTreeClassifier(criterion="gini", max_depth=depth)

    # 4.2 模型训练
    estimator.fit(x_train, y_train)

    # 5. 模型评估
    score_test = estimator.score(x_test, y_test)
    score_train = estimator.score(x_train, y_train)
    acc_test_lst.append(score_test)
    acc_train_lst.append(score_train)
    
plt.figure(dpi=300)
plt.plot(tree_depth, acc_test_lst, label="测试集")
plt.plot(tree_depth, acc_train_lst, label="训练集")
plt.title("决策树深度与测试集准确率的关系")
plt.xlabel("决策树深度")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

result:

(Figure: training and test accuracy vs. decision tree depth)

According to the plot, both the training accuracy and the test accuracy increase with the depth of the tree at first, but after a certain depth the accuracy levels off.

This could mean that, after a certain tree depth, the model is complex enough to fit the training data well. However, if you continue to increase the depth of the tree, it may cause overfitting problems.

Therefore, we can try to choose an appropriate tree depth to avoid over-fitting problems while ensuring the accuracy of the test set.
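One common way to choose such a depth is cross-validated grid search rather than a single train/test split; a sketch, assuming x_train and y_train from the steps above are available:

from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Search a small grid of depths with 5-fold cross-validation on the training set
param_grid = {"max_depth": list(range(1, 21))}
search = GridSearchCV(DecisionTreeClassifier(criterion="gini", random_state=22),
                      param_grid, cv=5)
search.fit(x_train, y_train)

print("best max_depth:", search.best_params_["max_depth"])
print("cross-validated accuracy:", search.best_score_)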

6.4 Decision Tree Visualization

6.4.1 Save tree structure to dot file

sklearn.tree.export_graphviz()

Function : Exports a decision tree in GraphViz DOT format. This function generates a GraphViz representation of the decision tree and writes it to out_file. Once exported, graphical renderings can be generated with, for example:

$ dot -Tps tree.dot -o tree.ps (PostScript format)
$ dot -Tpng tree.dot -o tree.png (PNG format)

Parameters :

  • decision_tree: The decision tree to export to GraphViz.
  • out_file: The handle or name of the output file. If None, the result is returned as a string.
  • max_depth: Indicates the maximum depth. If None, the tree is fully built.
  • feature_names: The name of each feature. If None, generic names are used ("x[0]", "x[1]", ...).
  • class_names: The name of each target class in ascending order. Only relevant for classification, does not support multiple outputs. If True, display a symbolic representation of the class name.
  • label: Whether to display informational labels such as impurities. Options include:
    • 'all' display at each node
    • 'root' is displayed only at the top root node
    • 'none' is not displayed on any nodes.
  • filled: When set to True, draws nodes to indicate dominant categories for classification, extreme values ​​for regression values, or purity for multi-output nodes.
  • leaves_parallel: When set to True, draws all leaf nodes at the bottom of the tree.
  • impurity: When set to True, show impurities at each node.
  • node_ids: When set to True, display the ID number on each node.
  • proportion: When set to True, changes the display of 'values' and/or 'samples' to scales and percentages.
  • rotate: When set to True, orients the tree left-to-right instead of top-to-bottom.
  • rounded: When set to True, draws a node box with rounded corners.
  • special_characters: When set to False, special characters are ignored for PostScript compatibility.
  • precision: Number of digits of precision for floating point numbers in the impurity, threshold, and value attributes of each node.
  • Return value : If out_file is None, returns a string representation of the input tree in GraphViz DOT format.

Example:

feature_names_lst = transfer.get_feature_names_out()
export_graphviz(estimator, out_file="../data/decision_tree.dot", 
                feature_names=['age', 'pclass', 'sex=女性', 'sex=男性'],
                fontname='Microsoft YaHei')

Note :

  • fontname='Microsoft YaHei' is set so that Chinese characters are displayed correctly instead of appearing garbled
  • If we don't know the feature names, we can call transfer.get_feature_names_out() to get the list of features

At this point, the following file is generated:

(Figure: the generated decision_tree.dot file)

Because we limited the depth of the decision tree to only 5 layers, the tree is very shallow and the corresponding file is small.

The contents of the dot file are as follows:

digraph Tree {
node [shape=box, fontname="Microsoft YaHei"] ;
edge [fontname="Microsoft YaHei"] ;
0 [label="sex=男性 <= 0.5\nentropy = 0.955\nsamples = 981\nvalue = [613, 368]"] ;
1 [label="pclass <= 2.5\nentropy = 0.829\nsamples = 340\nvalue = [89, 251]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="pclass <= 1.5\nentropy = 0.351\nsamples = 182\nvalue = [12, 170]"] ;
1 -> 2 ;
3 [label="age <= 8.0\nentropy = 0.242\nsamples = 100\nvalue = [4, 96]"] ;
2 -> 3 ;
4 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0]"] ;
3 -> 4 ;
5 [label="age <= 62.5\nentropy = 0.196\nsamples = 99\nvalue = [3, 96]"] ;
3 -> 5 ;
6 [label="entropy = 0.147\nsamples = 95\nvalue = [2, 93]"] ;
5 -> 6 ;
7 [label="entropy = 0.811\nsamples = 4\nvalue = [1, 3]"] ;
5 -> 7 ;
8 [label="age <= 17.5\nentropy = 0.461\nsamples = 82\nvalue = [8, 74]"] ;
2 -> 8 ;
9 [label="entropy = 0.0\nsamples = 13\nvalue = [0, 13]"] ;
8 -> 9 ;
10 [label="age <= 44.5\nentropy = 0.518\nsamples = 69\nvalue = [8, 61]"] ;
8 -> 10 ;
11 [label="entropy = 0.561\nsamples = 61\nvalue = [8, 53]"] ;
10 -> 11 ;
12 [label="entropy = 0.0\nsamples = 8\nvalue = [0, 8]"] ;
10 -> 12 ;
13 [label="age <= 0.875\nentropy = 1.0\nsamples = 158\nvalue = [77, 81]"] ;
1 -> 13 ;
14 [label="entropy = 0.0\nsamples = 3\nvalue = [0, 3]"] ;
13 -> 14 ;
15 [label="age <= 46.0\nentropy = 1.0\nsamples = 155\nvalue = [77, 78]"] ;
13 -> 15 ;
16 [label="age <= 36.5\nentropy = 1.0\nsamples = 153\nvalue = [77, 76]"] ;
15 -> 16 ;
17 [label="entropy = 0.999\nsamples = 140\nvalue = [67, 73]"] ;
16 -> 17 ;
18 [label="entropy = 0.779\nsamples = 13\nvalue = [10, 3]"] ;
16 -> 18 ;
19 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 2]"] ;
15 -> 19 ;
20 [label="age <= 12.5\nentropy = 0.686\nsamples = 641\nvalue = [524, 117]"] ;
0 -> 20 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
21 [label="pclass <= 2.5\nentropy = 0.964\nsamples = 36\nvalue = [14, 22]"] ;
20 -> 21 ;
22 [label="entropy = 0.0\nsamples = 12\nvalue = [0, 12]"] ;
21 -> 22 ;
23 [label="age <= 0.585\nentropy = 0.98\nsamples = 24\nvalue = [14, 10]"] ;
21 -> 23 ;
24 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
23 -> 24 ;
25 [label="age <= 2.5\nentropy = 0.966\nsamples = 23\nvalue = [14, 9]"] ;
23 -> 25 ;
26 [label="entropy = 0.592\nsamples = 7\nvalue = [6, 1]"] ;
25 -> 26 ;
27 [label="entropy = 1.0\nsamples = 16\nvalue = [8, 8]"] ;
25 -> 27 ;
28 [label="pclass <= 1.5\nentropy = 0.627\nsamples = 605\nvalue = [510, 95]"] ;
20 -> 28 ;
29 [label="age <= 53.5\nentropy = 0.859\nsamples = 131\nvalue = [94, 37]"] ;
28 -> 29 ;
30 [label="age <= 15.0\nentropy = 0.912\nsamples = 107\nvalue = [72, 35]"] ;
29 -> 30 ;
31 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
30 -> 31 ;
32 [label="entropy = 0.905\nsamples = 106\nvalue = [72, 34]"] ;
30 -> 32 ;
33 [label="age <= 75.5\nentropy = 0.414\nsamples = 24\nvalue = [22, 2]"] ;
29 -> 33 ;
34 [label="entropy = 0.258\nsamples = 23\nvalue = [22, 1]"] ;
33 -> 34 ;
35 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
33 -> 35 ;
36 [label="age <= 45.25\nentropy = 0.536\nsamples = 474\nvalue = [416, 58]"] ;
28 -> 36 ;
37 [label="age <= 44.5\nentropy = 0.56\nsamples = 443\nvalue = [385, 58]"] ;
36 -> 37 ;
38 [label="entropy = 0.555\nsamples = 442\nvalue = [385, 57]"] ;
37 -> 38 ;
39 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 1]"] ;
37 -> 39 ;
40 [label="entropy = 0.0\nsamples = 31\nvalue = [31, 0]"] ;
36 -> 40 ;
}

6.4.2 Decision Tree Visualization

Displaying the decision tree in text form is obviously not intuitive enough, so other tools can be used. Here are two ways:

  1. Use Graphviz to visualize the .dot file.
  2. Use a third-party website to visualize the .dot file.

6.4.2.1 Visualization using Graphviz

We can use Graphviz to convert a .dot file to an image format such as png or jpg. Run the conversion from the command line, for example:

# 生成.png图片
dot -Tpng your_file.dot -o output.png

# 生成.jpg图片
dot -Tjpg your_file.dot -o output.jpg

where:

  • -Tpng and -Tjpg are options of the Graphviz command-line tool that specify the format of the output file.
    • -Tpng indicates that the output format is PNG
    • -Tjpg indicates that the output format is JPG
    • Other options generate image files in other formats
      • For example, -Tgif generates GIF files
      • -Tpdf generates PDF files, etc.
  • your_file.dot is the name of the .dot file
  • output.png is the name of the image file to be generated

Note :

  • The output format specified by the -T option should be consistent with the extension of the output file, to ensure that the generated image file is correct.
  • If the format specified by -T is inconsistent with the output file's extension, the generated image may fail to open or display properly. It is therefore recommended to make sure the -T option matches the output file extension.

The resulting picture is shown below.

(Figure: decision tree rendered by Graphviz)

Because we limited the depth of the decision tree to only 5 layers, the tree is very shallow
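Alternatively, if the graphviz Python package is installed (in addition to the Graphviz binaries), the same conversion can be scripted; a sketch, assuming the .dot file path used above:

import graphviz

# Read the exported DOT file and render it to PNG (requires the Graphviz binaries on PATH)
with open("../data/decision_tree.dot", encoding="utf-8") as f:
    dot_data = f.read()

graph = graphviz.Source(dot_data)
graph.render("decision_tree", format="png", cleanup=True)  # writes decision_tree.png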

6.4.2.2 Using third-party websites for visualization

Website address: http://webgraphviz.com

(Figure: the webgraphviz.com page)

We can copy the content of the .dot file we just generated into this website for visual display; the result is as follows:

(Figure: decision tree rendered by webgraphviz.com)

Because we limited the depth of the decision tree to only 5 layers, the tree is very shallow


Summary :

  • Case process analysis [Understand]
    1. Get the data
    2. Basic data processing
      1. Determine feature values and target values
      2. Handle missing values
      3. Split the dataset
    3. Feature engineering (dictionary feature extraction)
    4. Machine learning (decision tree)
    5. Model evaluation
  • Decision tree visualization [Understand]
    • Decision tree export: sklearn.tree.export_graphviz()
    • Decision tree visualization
      • Command line: dot -T<image format> <file name>.dot -o <output image name>.<image format>
      • Third-party website: http://webgraphviz.com

7. Regression decision tree

Learning objectives:

  • Know how a regression decision tree works (its implementation principle)

As mentioned earlier, regarding data types, we can mainly divide them into two categories, ① continuous data and ② discrete data.

In the face of different data, decision trees can also be divided into two types:

  1. Classification decision tree : mainly used to deal with discrete data
  2. Regression decision tree : mainly used for processing continuous data

Continuous data is mainly used for regression; discrete data is mainly used for classification

7.1 Principle overview

Whether it is a regression decision tree or a classification decision tree, there are two core problems:

  1. How to choose partition nodes?
  2. How to decide the output value of the leaf node?

A regression tree corresponds to a partition of the input space (i.e., the feature space) into units, together with an output value on each unit. In a classification tree, we use information-theoretic measures to compute and select the best split point. In a regression tree, a heuristic approach is used instead.

Suppose we have $n$ features, and feature $i$ has $s_i$ possible values ($i \in [1, n]$). We traverse all the features, try every value of each feature, and split the space on the feature $j$ and value $s$ that minimize the loss function; this gives one split point. The formula describing this process is:

$$\min_{j,\,s}\left[\min_{c_1} \mathcal{L}(y_i, c_1) + \min_{c_2} \mathcal{L}(y_i, c_2)\right]$$

where:

  • $n$ is the number of features
  • $s_i$ is the number of values of the $i$-th feature
  • $j$ and $s$ are the feature and value of the best split point
  • $c_1$ and $c_2$ are the fixed output values of the two regions after the split
  • $\mathcal{L}$ is the loss function

Assume the input space is partitioned into $M$ units $R_1, R_2, \dots, R_M$. Then the output value of each region is $c_m = \mathrm{avg}(y_i \mid x_i \in R_m)$, i.e., the mean of the $y$ values of all points in that region.

where:

  • $M$ is the number of units the input space is divided into
  • $R_1, R_2, \dots, R_M$ are the units
  • $c_m$ is the output value of region $m$, equal to the mean of the $y$ values of the points in that region
  • $\mathrm{avg}(y_i \mid x_i \in R_m)$ denotes the average of all $y_i$ whose $x_i$ belongs to region $R_m$

Q1 : What is a "unit"?
A1 : In a decision tree, a unit (also called a region) is a subregion into which the input space is divided. The decision tree divides the input space into several units by continuously selecting the best division point, and the data points in each unit have similar characteristics. Each cell has a fixed output value that predicts the target value for the data points within that region.

Q2 : Is a partition node the same thing as a tree node? Can a partition node be a leaf node?
A2 : A partition node is a non-leaf node of the decision tree, used to divide the input space into sub-regions. Each partition node has a partition condition that determines which sub-region a data point belongs to. A partition node is not a leaf node. A leaf node is a node with no children; it represents a unit and is used to predict the target value of the data points in that region.

Q3 : So a unit corresponds to a leaf node, and a partition node is a non-leaf node, right?
A3 : Yes. In a decision tree, each unit corresponds to a leaf node, and each leaf node represents a unit. A partition node is a non-leaf node, used to divide the input space into sub-regions.


Example : As shown in the figure below, suppose we want to regress the ages of the residents in a building, and we divide the building into 3 regions $R_1, R_2, R_3$ (red lines). Then the output of $R_1$ is the average age of the four residents in the first column, the output of $R_2$ is the average age of the four residents in the second column, and the output of $R_3$ is the average age of the eight residents in the third and fourth columns.

(Figure: a building divided into three regions $R_1, R_2, R_3$)

7.2 Algorithm description

Input : training dataset $D$
Output : regression tree $f(x)$

In the input space where the training data set is located, recursively divide each area into two sub-areas and determine the output value on each sub-area, and construct a binary decision tree:

1. Select the optimal splitting feature $j$ and split point $s$ by solving

$$\min_{j,\,s}\left[\min_{c_1}\sum_{x_i \in R_1(j,s)} (y_i - c_1)^2 + \min_{c_2}\sum_{x_i \in R_2(j,s)} (y_i - c_2)^2\right]$$

Traverse the features $j$; for a fixed splitting feature $j$, scan the split points $s$, and choose the pair $(j, s)$ that minimizes the expression above.

2. Use the selected pair $(j, s)$ to divide the region and determine the corresponding output values:

$$R_1(j, s) = \{x \mid x^{(j)} \le s\}, \qquad R_2(j, s) = \{x \mid x^{(j)} > s\}$$

$$\hat{c}_m = \frac{1}{N_m}\sum_{x_i \in R_m(j, s)} y_i, \quad x_i \in R_m,\ m = 1, 2$$

where $N_m$ is the number of samples in region $R_m$.

3. Continue to call steps 1 and 2 for the two sub-regions until the stop condition is met.

4. Divide the input space into $M$ regions $R_1, R_2, \dots, R_M$ and generate the decision tree:

$$f(x) = \sum_{m = 1}^{M} \hat{c}_m \, I(x \in R_m)$$

where:

  • $D$ is the training dataset
  • $f(x)$ is the regression tree
  • $j$ and $s$ are the optimal splitting feature and split point
  • $R_1(j, s)$ and $R_2(j, s)$ are the two sub-regions produced by the split
  • $c_1$ and $c_2$ are the output values in the two sub-regions
  • $\hat{c}_m$ is the output value of region $m$, equal to the mean of the $y$ values in that region
  • $M$ is the number of regions the input space is divided into
  • $R_1, R_2, \dots, R_M$ are the regions

7.3 Simple example

For ease of understanding, a simple example will be used to deepen the understanding of regression decision tree. The training data is shown in the table below, and our goal is to obtain a least squares regression tree.

| $x$ (feature value) | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| $y$ (target value) | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 | 8.9 | 8.7 | 9 | 9.05 |

7.3.1 Example calculation process

1. Select the optimal splitting feature $j$ and the optimal split point $s$

  • First question: choose the optimal splitting feature
    • This dataset has only one feature, so the optimal splitting feature is naturally $x$
  • Second question: we consider 9 candidate split points $[1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]$
    • The loss function is the squared loss $\mathcal{L}(y, f(x)) = [f(x) - y]^2$, where $f(x)$ is the predicted value and $y$ is the true (target) value
    • Substitute each of the nine split points into the formula below, where $c_m = \mathrm{avg}(y_i \mid x_i \in R_m)$

a. Compute the sub-region output values:

When the split point is $s = 1.5$, the data is divided into two sub-regions $R_1$ and $R_2$: $R_1$ contains the data point with feature value $1$, and $R_2$ contains the data points with feature values $2, 3, 4, 5, 6, 7, 8, 9, 10$.

$c_1$ and $c_2$ are the output values of these two sub-regions, computed by summing the target values in each sub-region and dividing by the number of data points. The output values of the two regions are therefore:

  • $c_1 = 5.56$
  • $c_2 = \frac{5.7+5.91+6.4+6.8+7.05+8.9+8.7+9+9.05}{9} = 7.50$

When the split point is $s = 2.5$, the data is divided into two sub-regions $R_1$ and $R_2$: $R_1$ contains the data points with feature values $1, 2$, and $R_2$ contains the data points with feature values $3, 4, 5, 6, 7, 8, 9, 10$. The output values of the two regions are:

  • $c_1 = \frac{5.56 + 5.7}{2} = 5.63$
  • $c_2 = \frac{5.91+6.4+6.8+7.05+8.9+8.7+9+9.05}{8} = 7.73$

Similarly, we can compute the sub-region output values for the other split points, as shown in the following table:

| $s$ | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| $c_1$ | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 | 6.24 | 6.62 | 6.88 | 7.11 |
| $c_2$ | 7.5 | 7.73 | 7.99 | 8.25 | 8.54 | 8.91 | 8.92 | 9.03 | 9.05 |

b. Compute the loss function values and find the optimal split point:

Substitute the values of $c_1$ and $c_2$ into the squared loss $\mathcal{L}(y, f(x)) = [f(x) - y]^2$, where $f(x)$ is the predicted value and $y$ is the true (target) value.

When the split point is $s = 1.5$, the total loss is:

$$\begin{aligned} \mathcal{L} &= \sum_{x_i \in R_1} [f(x_i) - y_i]^2 + \sum_{x_i \in R_2} [f(x_i) - y_i]^2 \\ &= [5.56 - 5.56]^2 + [7.50 - 5.7]^2 + [7.50 - 5.91]^2 + \dots + [7.50 - 9.05]^2 \\ &= 0 + (1.8)^2 + (1.59)^2 + \dots + (-1.55)^2 = 15.72 \end{aligned}$$

When the split point is $s = 2.5$, the total loss is:

$$\begin{aligned} \mathcal{L} &= \sum_{x_i \in R_1} [f(x_i) - y_i]^2 + \sum_{x_i \in R_2} [f(x_i) - y_i]^2 \\ &= [5.63 - 5.56]^2 + [5.63 - 5.7]^2 + [7.73 - 5.91]^2 + \dots + [7.73 - 9.05]^2 \\ &= (0.07)^2 + (-0.07)^2 + (1.82)^2 + \dots + (-1.32)^2 = 12.07 \end{aligned}$$

In the same way, we compute the loss function values for the other split points and obtain the following table:

| $s$ | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 | 6.5 | 7.5 | 8.5 | 9.5 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| $m(s)$ | 15.72 | 12.07 | 8.36 | 5.78 | 3.91 | 1.93 | 8.01 | 11.73 | 15.74 |

Clearly, $m(s)$ is smallest at $s = 6.5$, so the first splitting variable is $[j = x, s = 6.5]$.


Q : Why is the notation $m(s)$ used here rather than $\mathcal{L}(y, f(x))$?
A : $m(s)$ and $\mathcal{L}$ both denote the loss function. In a regression decision tree, the loss function measures the difference between the predicted values and the true values within the divided sub-regions. Different texts may use different symbols for the loss function, but the meaning is the same.

$m(s)$ denotes the value of the loss function at split point $s$. So when computing the loss at different split points, either $m(s)$ or $\mathcal{L}$ may be used.
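The table above can be reproduced with a few lines of code; a sketch of the exhaustive split search on the toy data (the variable names are illustrative):

import numpy as np

# Toy training data from the example
x = np.arange(1, 11)
y = np.array([5.56, 5.7, 5.91, 6.4, 6.8, 7.05, 8.9, 8.7, 9.0, 9.05])

def split_loss(s):
    """Total squared loss m(s) when splitting at point s."""
    left, right = y[x <= s], y[x > s]
    c1, c2 = left.mean(), right.mean()          # region outputs = mean of the targets
    return ((left - c1) ** 2).sum() + ((right - c2) ** 2).sum()

candidates = np.arange(1.5, 10, 1.0)            # [1.5, 2.5, ..., 9.5]
losses = {float(s): round(float(split_loss(s)), 2) for s in candidates}
print(losses)                                   # the smallest loss, 1.93, is at s = 6.5
best = min(losses, key=losses.get)
print("best split point:", best)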


2. Use the selected $(j, s)$ to divide the regions and determine the output values:

  • The two regions are: $R_1 = \{1, 2, 3, 4, 5, 6\}$, $R_2 = \{7, 8, 9, 10\}$
  • The output values are $c_m = \mathrm{avg}(y_i \mid x_i \in R_m)$: $c_1 = 6.24$, $c_2 = 8.91$

3. Call steps 1 and 2 again and continue splitting:

Continue splitting $R_1$:

| $x$ (feature value) | 1 | 2 | 3 | 4 | 5 | 6 |
| --- | --- | --- | --- | --- | --- | --- |
| $y$ (target value) | 5.56 | 5.7 | 5.91 | 6.4 | 6.8 | 7.05 |

Take split points $[1.5, 2.5, 3.5, 4.5, 5.5]$; the output value $c$ of each region is as follows:

| $s$ | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
| --- | --- | --- | --- | --- | --- |
| $c_1$ | 5.56 | 5.63 | 5.72 | 5.89 | 6.07 |
| $c_2$ | 6.37 | 6.54 | 6.75 | 6.93 | 7.02 |

Compute the loss function values $m(s)$:

| $s$ | 1.5 | 2.5 | 3.5 | 4.5 | 5.5 |
| --- | --- | --- | --- | --- | --- |
| $m(s)$ | 1.3087 | 0.754 | 0.2771 | 0.4368 | 1.0644 |

$m(s)$ is smallest at $s = 3.5$.

…and so on, repeating the process.

The division termination conditions of the regression decision tree usually have the following types:

  1. The number of data points in the subregion is less than a preset threshold.
  2. The variance of the target value of the data points in the subregion is less than a pre-set threshold.
  3. The depth of the tree reaches a preset maximum depth.

When any one of the above conditions is satisfied, the division process will be terminated. These conditions can be tuned to the specific problem to obtain the best model performance.
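In scikit-learn, these stopping conditions map onto constructor parameters of DecisionTreeRegressor; a sketch of how they might be set (the specific values here are illustrative, not recommendations):

from sklearn.tree import DecisionTreeRegressor

# Each argument roughly corresponds to one of the stopping conditions listed above
reg = DecisionTreeRegressor(
    min_samples_leaf=2,          # condition 1: minimum number of samples in a leaf (sub-region)
    min_impurity_decrease=0.01,  # condition 2 (approximately): stop when the reduction in MSE is too small
    max_depth=3,                 # condition 3: maximum depth of the tree
)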

4. Generate the regression tree

Assuming the splitting stops after 3 regions have been generated, the final regression tree has the form:

$$T(x) = \begin{cases} 5.72 & x \le 3.5 \\ 6.75 & 3.5 < x \le 6.5 \\ 8.91 & 6.5 < x \end{cases}$$

The structure of this regression tree is as follows:

          [j=x, s=6.5]
          /         \
  [j=x, s=3.5]      R_2
   /       \
 R_11     R_12

where $R_{11}$, $R_{12}$, and $R_2$ are all leaf nodes.

This regression tree has three leaf nodes, corresponding to the three sub-regions $R_{11}$, $R_{12}$, and $R_2$. The splitting variable of the root node is $[j=x, s=6.5]$, which divides the data into two sub-regions $R_1$ and $R_2$. The left child of the root corresponds to sub-region $R_1$; its splitting variable is $[j=x, s=3.5]$, which divides $R_1$ further into two sub-regions $R_{11}$ and $R_{12}$. The left and right children of this node correspond to $R_{11}$ and $R_{12}$, both of which are leaf nodes. The right child of the root corresponds to sub-region $R_2$, which is also a leaf node.

where:

  • $j$ and $s$ denote the splitting feature and the split point, respectively
  • $j = x$ means the splitting feature is $x$, and $s = 6.5$ means the split point is $6.5$
    • With splitting variable $[j=x, s=6.5]$, the data is divided into two sub-regions according to the value of feature $x$: sub-region $R_1$ contains the data points with feature value less than or equal to $6.5$, and sub-region $R_2$ contains the data points with feature value greater than $6.5$.
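As a tiny sanity check, the piecewise-constant tree above can be written directly as a prediction function (a sketch; the thresholds and outputs are taken from the worked example):

def regression_tree_predict(x):
    """Predict with the regression tree T(x) built in the worked example."""
    if x <= 3.5:        # leaf R_11
        return 5.72
    elif x <= 6.5:      # leaf R_12
        return 6.75
    else:               # leaf R_2
        return 8.91

print([regression_tree_predict(v) for v in (2, 5, 9)])  # [5.72, 6.75, 8.91]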

Summary :

  • Input: training dataset $D$
  • Output: regression tree $f(x)$
  • Procedure: In the input space of the training data, recursively split each region into two sub-regions, determine the output value on each sub-region, and build a binary decision tree:
    1. Select the optimal splitting feature $j$ and split point $s$ by solving $\min_{j,\,s}\left[\min_{c_1}\sum_{x_i \in R_1(j,s)} (y_i - c_1)^2 + \min_{c_2}\sum_{x_i \in R_2(j,s)} (y_i - c_2)^2\right]$ — traverse the features $j$; for a fixed splitting feature $j$, scan the split points $s$, and choose the pair $(j, s)$ that minimizes the expression.
    2. Use the selected pair $(j, s)$ to divide the region and determine the corresponding output values: $R_1(j, s) = \{x \mid x^{(j)} \le s\}$, $R_2(j, s) = \{x \mid x^{(j)} > s\}$, and $\hat{c}_m = \frac{1}{N_m}\sum_{x_i \in R_m(j, s)} y_i$ for $x_i \in R_m$, $m = 1, 2$.
    3. Continue calling steps 1 and 2 on the two sub-regions until the stopping condition is met.
    4. Divide the input space into $M$ regions $R_1, R_2, \dots, R_M$ and generate the decision tree $f(x) = \sum_{m=1}^{M} \hat{c}_m \, I(x \in R_m)$.

7.4 Comparison of Regression Decision Tree and Linear Regression

import numpy as np 
import matplotlib.pyplot as plt 
from sklearn.tree import DecisionTreeRegressor 
from sklearn.linear_model import LinearRegression
from pylab import mpl
# 设置中文字体
mpl.rcParams["font.sans-serif"] = ["SimHei"]
# 设置正常显示符号
mpl.rcParams["axes.unicode_minus"] = False


# 1. 生成数据 
x = np.array(list(range(1, 11))).reshape(-1, 1)  # 使其变为列向量
y = np.array([5.56, 5.70, 5.91, 6.40, 6.80, 7.05, 8.90, 8.70, 9.00, 9.05]) 

# 2. 训练模型 
model_1 = DecisionTreeRegressor(max_depth=1)  # 决策树模型
model_2 = DecisionTreeRegressor(max_depth=3)  # 决策树模型
model_3 = LinearRegression()  # 线性回归模型
model_1.fit(x, y) 
model_2.fit(x, y) 
model_3.fit(x, y) 

# 3. 模型预测 
X_test = np.arange(0.0, 10.0, 0.01).reshape(-1, 1)  # 生成1000个数,用于预测模型 
predict_1 = model_1.predict(X_test) 
predict_2 = model_2.predict(X_test) 
predict_3 = model_3.predict(X_test) 

# 4. 结果可视化 
plt.figure(dpi=300) 
plt.scatter(x, y, label="原始数据(目标值)") 
plt.plot(X_test, predict_1, label="回归决策树: max_depth=1") 
plt.plot(X_test, predict_2, label="回归决策树: max_depth=3") 
plt.plot(X_test, predict_3, label="线性回归") 

plt.xlabel("数据") 
plt.ylabel("预测值") 
plt.title("线性回归与回归决策树效果对比")
plt.grid(alpha=0.5)

plt.legend() 
plt.show()

result:

(Figure: comparison of linear regression and regression decision trees on the toy data)

8. Decision tree summary

8.1 Advantages

  1. Easy to understand and explain .
    • The structure of the decision tree can be visualized and easily understood by non-experts.
  2. Data preparation is simple .
    • Decision trees do not require complex preprocessing of the data, such as normalizing or removing missing values.
  3. Ability to handle both numerical and categorical data .
  4. Not affected by data scaling .
  5. Computational cost is relatively low .

These advantages make decision trees widely used in many fields.

8.2 Disadvantages

  1. Prone to overfitting . Decision tree models can produce overly complex trees that generalize poorly.
    • Overfitting can be avoided by pruning, setting the minimum number of samples required for a leaf node, or setting the maximum depth of the tree.
  2. Instability . Small changes in the data can result in a completely different tree.
    • This problem can be alleviated by using ensembles of decision trees.
  3. It is relatively difficult to make predictions for continuous fields .
  4. When there are many categories, the error rate may increase more quickly .

These disadvantages need to be kept in mind when using decision trees.

8.3 Improved method

For the shortcomings of decision trees, there are some improvement methods that can be used. For example:

  1. Avoid overfitting .
    • Overfitting can be avoided by pruning, setting the minimum number of samples required for a leaf node, or setting the maximum depth of the tree.
    • Pruning includes pre-pruning and post-pruning.
      • The former controls the depth of the tree or the number of nodes by setting thresholds, and operates before nodes are split, thereby preventing overfitting.
      • The latter examines non-leaf nodes from the bottom up: if replacing an internal node's subtree with a leaf node improves the generalization ability of the decision tree, the replacement is made.
  2. Ensemble using decision trees .
    • The stability and accuracy of the model can be improved by integrating multiple decision trees.
      • For example, the random forest algorithm is an ensemble learning algorithm based on decision trees, which improves the accuracy and stability of the model by building multiple decision trees and combining their prediction results.
  3. Discretize continuous fields .
    • Continuous fields can be discretized into categorical variables so that decision trees can handle them better.
  4. Resample class-imbalanced data .
    • Class-imbalanced data can be resampled to reduce the error rate.

These methods can help improve decision tree models, increasing their accuracy and stability.
