MNIST Digit Classification Based on Decision Tree

1. About the author

Hou Qingshan, male, School of Electronic Information, Xi'an Polytechnic University, graduate student (class of 2021)
Research direction: smoke image segmentation
Email: [email protected]

Liu Shuaibo, male, School of Electronic Information, Xi'an Polytechnic University, graduate student (class of 2021), Zhang Hongwei Artificial Intelligence Research Group
Research direction: Machine Vision and Artificial Intelligence
Email: [email protected]

2. Theoretical background

2.1 Introduction to the principle of decision trees

A decision tree classifies data through a series of rules; it provides a rule-like answer to the question "what value is obtained under what conditions". Decision trees are divided into classification trees and regression trees: classification trees are decision trees for discrete target variables, and regression trees are decision trees for continuous target variables.

Surveys have shown that the decision tree is also among the most frequently used data mining algorithms, and its concept is very simple. An important reason why the decision tree algorithm is so popular is that the user hardly needs to understand the underlying machine learning theory or delve into how it works. Intuitively, a decision tree classifier is like a flow chart composed of judgment modules and termination blocks: a termination block represents a classification result (that is, a leaf of the tree), and a judgment module represents a test on the value of a feature (the judgment module has as many branches as the feature has values).

If efficiency is not considered, cascading judgments over all the features of a sample will eventually assign the sample to some class termination block. In practice, however, only some of the features play a decisive role in the classification. Constructing a decision tree is the process of finding these decisive features and building an inverted tree according to how decisive they are: the most decisive feature becomes the root node, and then, within the sub-dataset under each branch, the next most decisive feature is found recursively until all the data in a sub-dataset belong to the same class. The construction of a decision tree is therefore essentially a recursive process of partitioning the data set according to its features, and the first problem to solve is which feature of the current data set plays the most decisive role in classifying the data.

Figure 1 Watermelon decision tree
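
To make the "judgment module / termination block" picture concrete, here is a minimal hand-written classifier in the spirit of Figure 1. The watermelon attributes, values and branching order are invented for illustration and do not reproduce the figure exactly:

```python
def classify_watermelon(texture, root, sound):
    """A tiny hand-coded decision tree: each `if` is a judgment module,
    each `return` is a termination block (a leaf holding the class label)."""
    if texture == "clear":              # root node: test the most decisive feature first
        if root == "curled":
            return "good melon"
        else:
            return "bad melon"
    elif texture == "slightly blurry":
        if sound == "dull":
            return "good melon"
        else:
            return "bad melon"
    else:                               # texture == "blurry"
        return "bad melon"

print(classify_watermelon("clear", "curled", "dull"))  # -> good melon
```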

2.2 Generation process of a decision tree

The generation process of a decision tree is mainly divided into the following three parts:

  • Feature selection: feature selection means choosing one feature from the many features in the training data as the splitting criterion for the current node. There are many different quantitative evaluation criteria for selecting a feature, and the different criteria give rise to different decision tree algorithms.

  • Decision tree generation: according to the selected feature-evaluation criterion, child nodes are generated recursively from top to bottom, and the tree stops growing when the data at a node can no longer be split. In terms of the tree structure, the recursive formulation is the easiest to understand; a minimal sketch of this recursion is given after this list.

  • Pruning: decision trees are prone to overfitting, so pruning is generally needed to reduce the size of the tree structure and alleviate overfitting. Pruning techniques come in two types: pre-pruning and post-pruning.
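
As a minimal sketch of the recursive generation step (not the scikit-learn implementation used later in this article), the skeleton below grows a tree until every node is pure or no features remain; `choose_best_feature` stands for whichever splitting criterion is chosen and is left abstract here:

```python
from collections import Counter

def build_tree(rows, labels, features, choose_best_feature):
    """Recursively build a decision tree as nested dicts.

    rows: list of dicts mapping feature name -> value
    labels: list of class labels, aligned with rows
    features: list of feature names still available for splitting
    choose_best_feature: callable implementing the splitting criterion
    """
    # Stop if the node is pure: every sample has the same class.
    if len(set(labels)) == 1:
        return labels[0]
    # Stop if no features remain: return the majority class.
    if not features:
        return Counter(labels).most_common(1)[0][0]

    best = choose_best_feature(rows, labels, features)  # most decisive feature
    node = {best: {}}
    for value in set(row[best] for row in rows):
        # Keep only the samples that take this value and recurse on them.
        subset = [(r, l) for r, l in zip(rows, labels) if r[best] == value]
        sub_rows = [r for r, _ in subset]
        sub_labels = [l for _, l in subset]
        remaining = [f for f in features if f != best]
        node[best][value] = build_tree(sub_rows, sub_labels, remaining,
                                       choose_best_feature)
    return node
```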

2.3 Three Decision Trees Based on Information Theory

The most important principle when partitioning a data set is to make disordered data more ordered. If a training set has 20 features, which one should be chosen as the basis for partitioning? This must be decided by a quantitative method, and among the many quantitative methods one class uses information-theoretic measures to evaluate the split. Decision tree algorithms based on information theory include ID3, C4.5 and CART, of which C4.5 and CART are derived from the ID3 algorithm.

CART and C4.5 support continuously distributed feature values, mainly by using binary splits on continuous variables: a specific split value is found, and if the feature value is greater than the split value the sample goes to one subtree, otherwise it goes to the other. The split value is chosen so as to reduce the "degree of disorder" in the resulting subtrees; the precise definition of this measure differs between C4.5 and CART.

The ID3 algorithm was invented by Ross Quinlan and is based on Occam's razor: a smaller decision tree is preferable to a larger one (keep it simple). In the ID3 algorithm, features are evaluated and selected according to the information gain of information theory, and at each step the feature with the largest information gain is chosen as the judgment module. ID3 can partition nominal data sets and has no built-in pruning process; to reduce overfitting, adjacent leaf nodes that do not yield a large information gain can be merged by pruning (for example, by setting an information gain threshold). A disadvantage of information gain is that it is biased towards attributes with many distinct values: the more different values an attribute takes in the training set, the more likely it is to be chosen as the splitting attribute, which is sometimes meaningless. In addition, ID3 cannot handle continuously distributed features, which motivated the C4.5 algorithm; the CART algorithm also supports continuous features.
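
As a small illustration of the criterion ID3 uses, the sketch below computes the Shannon entropy of a label set and the information gain of splitting on one discrete feature. The feature and label arrays are made up for the example and are not part of the original experiment:

```python
import numpy as np

def entropy(labels):
    """Shannon entropy H = -sum(p * log2(p)) over the class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -(p * np.log2(p)).sum()

def information_gain(feature, labels):
    """Entropy of the parent minus the weighted entropy of the children
    obtained by splitting on each distinct value of `feature`."""
    parent = entropy(labels)
    children = 0.0
    for v in np.unique(feature):
        mask = feature == v
        children += mask.mean() * entropy(labels[mask])
    return parent - children

# Toy example: a feature that separates the classes fairly well.
feature = np.array(["a", "a", "a", "b", "b", "b"])
labels = np.array([1, 1, 0, 0, 0, 0])
print(information_gain(feature, labels))  # larger gain = better split under ID3
```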

C4.5 is an improved version of ID3 and inherits its advantages. The C4.5 algorithm uses the information gain ratio to select attributes, which overcomes the bias towards attributes with many values that arises when plain information gain is used; it performs pruning during tree construction; it can discretize continuous attributes; and it can handle incomplete data. The classification rules generated by C4.5 are easy to understand and reasonably accurate, but the algorithm is inefficient because the data set must be scanned and sorted many times during tree construction. Also, because multiple scans of the data set are necessary, C4.5 is only suitable for data sets that fit in memory.

CART stands for Classification And Regression Tree. It uses the Gini index as the splitting criterion (selecting the feature with the smallest Gini index) and also includes a post-pruning step. Although ID3 and C4.5 can mine as much information as possible from the training set, the decision trees they generate tend to have many branches and a large scale. To simplify the scale of the tree and improve the efficiency of generating it, the CART algorithm, which selects the test attribute according to the Gini coefficient, was introduced.
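
For comparison with entropy, here is a minimal sketch of the Gini impurity of a split, again on made-up data; under CART the split with the smallest weighted Gini is preferred:

```python
import numpy as np

def gini(labels):
    """Gini impurity G = 1 - sum(p^2) over the class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - (p ** 2).sum()

def gini_of_split(feature, labels):
    """Weighted Gini impurity of the children after splitting on `feature`."""
    total = 0.0
    for v in np.unique(feature):
        mask = feature == v
        total += mask.mean() * gini(labels[mask])
    return total

feature = np.array(["a", "a", "a", "b", "b", "b"])
labels = np.array([1, 1, 0, 0, 0, 0])
print(gini_of_split(feature, labels))  # smaller weighted Gini = better split under CART
```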

This article uses an ID3-style tree: instead of implementing the algorithm from scratch, it calls the decision tree already built into scikit-learn, sklearn.tree.DecisionTreeClassifier(), and we only need to adjust some of its parameters.

2.4 Decision tree and parameter interpretation

classifier = tree.DecisionTreeClassifier(criterion='entropy',
                                         splitter='random',
                                         max_depth=None,
                                         min_samples_split=3,
                                         min_samples_leaf=2,
                                         min_weight_fraction_leaf=0.0,
                                         max_features=None,
                                         random_state=None,
                                         max_leaf_nodes=None,
                                         min_impurity_decrease=0.0,
                                         min_impurity_split=None,
                                         class_weight=None)

Figure 2 Detailed explanation of parameters

The parameters highlighted in yellow (in the original figure) are the ones most commonly tuned.

Link: sklearn.tree.DecisionTreeClassifier-scikit-learn Chinese Community

Since we use an ID3-style decision tree here, 'entropy' is selected as the criterion.

3. Classification of handwritten digits based on decision tree

3.1 Experimental code

## Handwritten digit classification based on a decision tree
from sklearn import tree
import numpy as np
# from sklearn.datasets import load_digits  # not used here; the data set is loaded from a local .npz file below


dataset = np.load('C:\\Users\\asus\\Desktop\\dateset\\mnist.npz')  # load the data set
x_train = dataset['x_train']  # training images (independent variables used for training)
y_train = dataset['y_train']  # class labels of the training data
x_test = dataset['x_test']    # remaining images, used to test the trained decision tree model
y_test = dataset['y_test']    # class labels of the test data, used to measure prediction accuracy

classifier = tree.DecisionTreeClassifier(criterion='entropy', splitter='random', max_depth=21, min_samples_split=3, random_state=40)
# classifier = tree.DecisionTreeClassifier(criterion='entropy', splitter='random', max_depth=None, min_samples_split=3, min_samples_leaf=2, min_weight_fraction_leaf=0.0, max_features=None, random_state=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, class_weight=None)

x_train = x_train.reshape(60000, 784)  # the first dimension indexes the images, the second indexes the pixels of each image
x_test = x_test.reshape(10000, 784)

classifier.fit(x_train, y_train)
score = classifier.score(x_test, y_test)
print(score)
# The last three lines (when uncommented) draw the decision tree:
# import matplotlib.pyplot as plt
# tree.plot_tree(classifier)
# plt.show()

3.2 Test process

The following are the parameters that gave the best results in my tests:

criterion='entropy', 
splitter='random', 
max_depth=21, 
min_samples_split=3,
random_state=40


Figure 3 max_depth debugging results

Figure 3 shows the tuning curve for max_depth. It can be seen that the test accuracy lies between 86% and 88%; when max_depth is greater than 21, the accuracy stabilizes at 88.6% and no longer fluctuates.
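
A curve like Figure 3 can be reproduced with a simple parameter sweep. This is a rough sketch, assuming x_train, y_train, x_test and y_test have already been loaded and reshaped as in section 3.1; the range of depths is arbitrary:

```python
import matplotlib.pyplot as plt
from sklearn import tree

depths = range(5, 31)
scores = []
for d in depths:
    clf = tree.DecisionTreeClassifier(criterion='entropy', splitter='random',
                                      max_depth=d, min_samples_split=3,
                                      random_state=40)
    clf.fit(x_train, y_train)                 # training data from section 3.1
    scores.append(clf.score(x_test, y_test))  # test accuracy at this depth

plt.plot(list(depths), scores)
plt.xlabel('max_depth')
plt.ylabel('test accuracy')
plt.show()
```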

Figure 4 Decision tree test results

4. Notes

4.1 Introduction to the MNIST dataset

The MNIST dataset comes from the National Institute of Standards and Technology (NIST).

It includes 60,000 training images and 10,000 test images.

Each image is 28*28 = 784 pixels.


Figure 5 Dataset image
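
To sanity-check the loaded arrays and reproduce an image like Figure 5, here is a quick sketch, assuming the mnist.npz file from section 3.1 is available locally (adjust the path as needed):

```python
import numpy as np
import matplotlib.pyplot as plt

dataset = np.load('mnist.npz')           # adjust the path to your local copy
x_train, y_train = dataset['x_train'], dataset['y_train']

print(x_train.shape, y_train.shape)      # expected: (60000, 28, 28) (60000,)

plt.imshow(x_train[0], cmap='gray')      # show the first training image
plt.title(f'label: {y_train[0]}')
plt.show()
```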

The following packages need to be installed in a Python 3.6 environment:

pip install scikit-learn

pip install numpy

pip install matplotlib  # this package is used for plotting

4.2 MNIST dataset acquisition

4.2.1 Via function call

from sklearn.datasets import load_digits
mnist = load_digits()  # note: load_digits() returns scikit-learn's small 8x8 digits set, not the 28x28 MNIST images
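
If the full 28*28 MNIST set is wanted via a function call instead of a local file, scikit-learn can fetch it from OpenML (downloaded on first use). This is a sketch of that alternative, not the method used in this article:

```python
from sklearn.datasets import fetch_openml

mnist = fetch_openml('mnist_784', as_frame=False)  # 70,000 flattened 784-pixel images
x, y = mnist.data, mnist.target
print(x.shape, y.shape)  # expected: (70000, 784) (70000,)
```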

4.2.2 Download to a local folder, and then call

Here I downloaded the data set and loaded it locally.

Download address 1: mnist.npz dataset is free - Programmer Sought

Download address 2: mnist data set download - mnist data set provides Baidu network disk download address_bigcindy's blog-CSDN blog_mnist data set

dataset = np.load('your mnist.npz path')  # load the data set
# an absolute path is recommended here


Original article: blog.csdn.net/m0_37758063/article/details/123646270