Data Science: How is the decision tree in sklearn designed and stored under the hood?


While doing some preliminary machine learning research, I reproduced the algorithm from a paper on transfer random forests. This required inheriting and extending sklearn's decision tree API, which in turn required understanding how the decision tree is designed and implemented under the hood. This article briefly introduces and shares that detail.


The decision tree is a classic machine learning algorithm that has gone through several major iterations such as ID3, C4.5, and CART. The decision tree built into sklearn is mainly based on the standard CART tree, although it differs in a few details of principle; for the algorithmic principles of decision trees, see the earlier article: Fun! 5000 words explain the basic principles of decision tree in a simple way. Decision trees can be used for both classification and regression, and they also form the foundation of many ensemble algorithms, so they play an important role in machine learning. For ensemble algorithms, see the earlier article: A picture introduces machine learning ensemble learning algorithms.


To explore how the decision tree in sklearn is designed and implemented, take the classification decision tree as an example and first look at which attributes and interfaces it exposes: use the dir function to check which attributes an initial (untrained) decision tree contains (attributes starting with "_" are filtered out, since they are generally built-in private attributes). The results are as follows:

[Screenshot: dir output of an untrained decision tree]
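A minimal sketch of that inspection (just filtering the dir output of an untrained classifier):

```python
from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier()
# public attributes/methods of an untrained tree ("_"-prefixed names filtered out)
print([name for name in dir(clf) if not name.startswith("_")])
```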

The interfaces above fall into two categories: attributes and methods (this may sound obvious: anyone familiar with classes in programming languages knows that a class mainly consists of attributes and methods, where attributes correspond to values and methods to behavior). To distinguish which are attributes and which are methods, you can use the auto-completion feature of the IPython interpreter.

Browsing the results, the attributes are mainly the parameters set when the decision tree is initialized, such as ccp_alpha (the pruning coefficient), class_weight (class weights), and criterion (the split criterion); there are also the main methods the decision tree implements, such as fit (model training) and predict (model prediction).

The focus of this article is how the trained "tree" is stored inside the decision tree, so we go one step further: train the decision tree on the iris dataset and then call dir again to see which attributes and interfaces have been added:

[Screenshot: dir output after training on iris]
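A sketch of that comparison, taking the set difference of the dir results before and after fitting on iris:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

clf = DecisionTreeClassifier()
before = {name for name in dir(clf) if not name.startswith("_")}

clf.fit(X, y)
after = {name for name in dir(clf) if not name.startswith("_")}

# attributes added by training
print(sorted(after - before))
```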


Taking this set difference, it is clear that training mainly adds 6 attributes (all attributes, not methods), whose meanings are easy to infer from their names:

  • classes_: the value of the classification label, that is, the unique value set of y

  • max_features_: maximum number of features

  • n_classes_: the number of classes, e.g. binary or multi-class, i.e. the length of the classes_ attribute

  • n_features_in_: the number of input features; equivalent to n_features_ in older versions of sklearn, which has been deprecated in favor of n_features_in_

  • n_outputs_: the number of outputs; a decision tree can handle not only a single classification problem but several classification problems at the same time. For example, given a set of features describing a person, it could simultaneously predict male/female, fat/thin, and tall/short, i.e. 3 classification problems and thus 3 outputs (note the distinction between multi-class and multi-output tasks)

  • tree_: without a doubt the focus of this article. It is an attribute set after training that holds the underlying tree structure, i.e. how the decision tree is actually stored.
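For reference, these six attributes can be read straight off the fitted estimator (continuing the snippet above; the values in the comments are what I would expect for iris):

```python
print(clf.classes_)        # [0 1 2]
print(clf.n_classes_)      # 3
print(clf.max_features_)   # 4
print(clf.n_features_in_)  # 4
print(clf.n_outputs_)      # 1
print(clf.tree_)           # <sklearn.tree._tree.Tree object at 0x...>
```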

Let's explore the tree_ attribute further. Printing it shows that it is a Tree object, along with its file path inside sklearn:

[Screenshot: printout of the tree_ attribute showing a Tree object]

We can view the documentation of the Tree class with the help function:

[Screenshot: help output for the Tree class]
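The two steps above as code (continuing with the fitted clf; help on the instance's class prints the Tree docstring):

```python
print(clf.tree_)        # <sklearn.tree._tree.Tree object at 0x...>
help(type(clf.tree_))   # docstring of sklearn.tree._tree.Tree
```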

The first sentence of that docstring describes the decision tree clearly:

Array-based representation of a binary decision tree.

That is: an array-based representation of a binary decision tree, in other words, a binary tree! In this representation, the i-th element of each array holds the information of the i-th node of the decision tree, and node 0 is the root. So what information does each node contain? The docstring points to the file where the node is defined: _tree.pxd. Looking at it, the definition of a node is easy to find:

[Screenshot: node definition in _tree.pxd]

Although it is written in Cython syntax, it is not hard to deduce the types and meanings of its fields (a rough Python sketch of the struct follows this list), for example:

  • left_child: size type (unsigned integer), representing the index of the left child node of the current node

  • right_child: similar to left_child

  • feature: size type, which represents the feature index used by the current node for splitting, that is, which column of features is used for splitting in the training set

  • threshold: double type, the splitting threshold for the feature chosen at the current node; samples with a value ≤ this threshold go to the left child, otherwise to the right child

  • n_node_samples: size type, representing the total number of samples falling into the node during training. Obviously, the n_node_samples of a parent node will be equal to the sum of n_node_samples of its left and right child nodes.
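As a rough Python analogue of that Cython struct (the field names follow _tree.pxd, but this dataclass is only an illustration, not sklearn code):

```python
from dataclasses import dataclass

@dataclass
class Node:
    left_child: int       # index of the left child node (-1 for a leaf)
    right_child: int      # index of the right child node (-1 for a leaf)
    feature: int          # index of the feature this node splits on (-2 for a leaf)
    threshold: float      # split threshold; samples with value <= threshold go left
    impurity: float       # impurity (e.g. Gini) of the samples at this node
    n_node_samples: int   # number of training samples that reached this node
```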

So far we have basically worked out the attribute definition of a single node. How, then, does the tree connect all of its nodes together? We turn again to the tree_ attribute of the trained decision tree to see what interfaces it exposes, still filtering out the built-in private attributes, and get the following results:

[Screenshot: dir output of the tree_ attribute]

Of course, you can also use the auto-completion feature of the IPython interpreter to check whether each interface is an attribute or a method:

[Screenshot: IPython auto-completion on the tree_ attribute]

Many of these attributes have been mentioned in the previous explanation of the node definition. Here we need to focus on the following attribute values:

  • node_count: the total number of nodes in the decision tree

  • children_left: array of left children of each node

  • children_right: array of right children of each node

  • feature: array of the split feature index chosen at each node

  • threshold: array of the split threshold chosen at each node

  • value: per-node counts of samples of each class

  • n_leaves: total number of leaf nodes

These are probably the most important ones. To further understand how the data in each attribute is stored, we again take the iris dataset as an example, train a decision tree with max_depth=2 (the root node corresponds to depth=0), and inspect the following values:

[Screenshot: node_count, n_leaves, children_left, children_right, feature, threshold, and value of the depth-2 tree]
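A sketch of that inspection (the commented values are from the run described here; the exact split features and thresholds can vary slightly with sklearn version and tie-breaking):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2).fit(X, y)
tree = clf.tree_

print(tree.node_count)      # 5
print(tree.n_leaves)        # 3
print(tree.children_left)   # [ 1 -1  3 -1 -1]
print(tree.children_right)  # [ 2 -1  4 -1 -1]
print(tree.feature)         # [ 3 -2  3 -2 -2]
print(tree.threshold)       # [ 0.8  -2.    1.75 -2.   -2.  ]
print(tree.value)           # shape (5, 1, 3): per-node class statistics
```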

It can be seen that:

  • The trained decision tree contains 5 nodes, including 3 leaf nodes

  • From children_left and children_right we can see that the left child index of node 0 (the root) is 1 and its right child index is 2; the left and right children of node 1 are both -1, which means it is a leaf; the left and right children of node 2 are 3 and 4 respectively, indicating it is an internal node that is split further

  • From feature and threshold we can see that node 0 (the root) is split on the feature with index 3 (the 4th column of the training set), with an optimal split threshold of 0.8; since node 1 is a leaf, it is not split further, and its feature and threshold fields are both -2

  • The value attribute shows the per-class sample counts at each node. Since iris is a three-class problem and the decision tree has 5 nodes, value is effectively a 5×3 array: the first row shows the class counts at the root, [50, 50, 50], and the second row shows the counts at its left child, [50, 0, 0], which is already pure and therefore not split further

  • In addition, the label of each leaf node is not stored directly in the tree, but it can be derived from the value attribute: the class with the most samples in a leaf is its predicted label. Beyond the label, the corresponding class probabilities can be obtained by normalizing the counts (see the sketch after this list)!
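A sketch of that derivation from value (continuing from the depth-2 tree above, and assuming a sklearn version in which value stores raw class counts; newer releases may store normalized fractions instead):

```python
import numpy as np

counts = clf.tree_.value.reshape(clf.tree_.node_count, -1)   # (n_nodes, n_classes)
leaf_mask = clf.tree_.children_left == -1                    # leaves have child index -1

leaf_labels = counts[leaf_mask].argmax(axis=1)                # majority class per leaf
leaf_probas = counts[leaf_mask] / counts[leaf_mask].sum(axis=1, keepdims=True)

print(leaf_labels)   # predicted class of each leaf
print(leaf_probas)   # class probabilities of each leaf
```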

Let's manually verify this with the iris dataset: split the root node on feature 3 with threshold 0.8 and count the samples that fall into the left child. We find that only the 50 samples of the first class remain after the split, i.e. the class counts are [50, 0, 0], exactly as expected.

[Screenshot: class counts of the samples routed to the left child after the manual split]
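A sketch of that manual check with numpy (using the split feature index and threshold reported above):

```python
import numpy as np
from sklearn.datasets import load_iris

X, y = load_iris(return_X_y=True)

left = y[X[:, 3] <= 0.8]               # labels of samples routed to the left child
print(np.bincount(left, minlength=3))  # expected: [50  0  0]
```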

In addition, from the parent-child correspondence in children_left and children_right we can infer that the nodes are laid out in pre-order (root-left-right). For the decision tree above, the corresponding binary tree after splitting is sketched below:

[Schematic: the depth-2 iris decision tree drawn as a binary tree]
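A small helper of my own (not part of sklearn) that reconstructs such a schematic by walking the arrays in root-left-right order; sklearn's export_text produces a similar printout:

```python
def print_preorder(tree, node=0, depth=0):
    """Print nodes in pre-order (root-left-right) using the array layout."""
    indent = "    " * depth
    if tree.children_left[node] == -1:   # leaf node
        print(f"{indent}node {node}: leaf, value={tree.value[node].ravel()}")
        return
    print(f"{indent}node {node}: X[:, {tree.feature[node]}] <= {tree.threshold[node]:.2f}")
    print_preorder(tree, tree.children_left[node], depth + 1)
    print_preorder(tree, tree.children_right[node], depth + 1)

# continuing from the depth-2 tree above
print_preorder(clf.tree_)
```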
