デシジョンツリーに基づくMNISTディジット分類

1.作者について

Hou Qingshan、男性、西安工科大学電子情報学部、2021年大学院生
研究の方向性:煙画像のセグメンテーション
Eメール:[email protected]

Liu Shuaibo、男性、西安工科大学電子情報学部、2021年大学院生、Zhang Hongwei人工知能研究グループ
研究の方向性:マシンビジョンと人工知能
Eメール:[email protected]

2.理論に関する知識の紹介

2.1ディシジョンツリーの原理の概要

デシジョンツリーは、一連のルールを通じてデータを分類するプロセスです。これは、どのような条件下でどのような値を取得するかのルールのような方法を提供します。決定木は分類木と回帰木に分けられます。分類木は離散変数の決定木として使用され、回帰木は連続変数の決定木として使用されます。

最近の調査によると、決定木も最も頻繁に使用されるデータマイニングアルゴリズムであり、その概念は非常に単純です。デシジョンツリーアルゴリズムが非常に人気がある非常に重要な理由は、ユーザーが基本的に機械学習アルゴリズムを理解する必要がなく、その仕組みを詳しく調べる必要がないことです。直感的には、決定木分類器は、判断モジュールと終了ブロックで構成されるフローチャートのようなものであり、終了ブロックは分類結果(つまり、ツリーの葉)を表します。判断モジュールは、機能の値に関する判断を表します(機能にはいくつかの値があり、判断モジュールにはいくつかのブランチがあります)。

効率などを考慮しない場合、サンプルのすべての機能のカスケード判断により、最終的に特定のサンプルがクラス終了ブロックに割り当てられます。実際、サンプルの特徴のいくつかは分類において決定的な役割を果たします。決定木の構築プロセスは、これらの決定的な特徴を見つけ、決定論の程度に従って逆ツリーを構築することです。これは、最も決定的な特徴です。ロールはルートノードとして使用され、サブデータセット内のすべてのデータが同じクラスに属するまで、各ブランチの下のサブデータセット内で次に大きい決定的な特徴を再帰的に見つけます。したがって、決定木を構築するプロセスは、本質的に、データの特徴に従ってデータセットを分類する再帰的なプロセスです。解決する必要がある最初の問題は、現在のデータセットのどの特徴がデータの分類において決定的な役割を果たすかです。
   ここに画像の説明を挿入

図1スイカの決定木

2.2決定木の生成プロセス

デシジョンツリーの生成プロセスは、主に次の3つの部分に分かれています。

  • 特徴選択:特徴選択とは、トレーニングデータ内の多くの特徴から、現在のノードの分割基準として特徴を選択することを指します。特徴の選択方法にはさまざまな定量的評価基準があり、それによってさまざまな決定木アルゴリズムが導き出されます。

  • デシジョンツリーの生成:選択した機能評価基準に従って、子ノードを上から下に再帰的に生成し、データセットが分離できなくなるまでデシジョンツリーの成長を停止します。ツリー構造に関しては、再帰構造が最も簡単に理解できる方法です。

  • 剪定:決定ツリーは過剰適合する傾向があります。一般に、剪定は、ツリー構造のサイズを縮小し、過剰適合を軽減するために必要です。剪定手法には、事前剪定と事後剪定の2種類があります。

2.3情報理論に基づく3つの決定木

データセットを分割する最大の原則は、順序付けされていないデータを整然と作成することです。トレーニングデータに20の特徴がある場合、分割の基礎としてどれが選択されますか?これは定量的な方法で判断する必要があります。複数の定量的な除算方法があり、その1つが「情報理論メトリック情報分類」です。情報理論に基づく決定木アルゴリズムには、ID3、CART、およびC4.5アルゴリズムが含まれ、そのうちC4.5およびCARTはID3アルゴリズムから派生しています。

CARTおよびC4.5は、データ機能が連続的に分散されている場合の処理​​をサポートします。主に、バイナリセグメンテーションを使用して連続変数を処理します。つまり、特定の値(分割値)を見つけます。固有値が分割値より大きい場合は、左側のサブツリーに移動するか、右側のサブツリーを歩きます。この分割値を選択する原則は、分割されたサブツリーの「混乱の程度」を減らすことです。具体的には、C4.5アルゴリズムとCARTアルゴリズムの定義は異なります。

ID3アルゴリズムはRossQuinlanによって発明され、「Occam's Razor」に基づいています。決定木が小さいほど、決定木は大きくなります(単純な理論になります)。ID3アルゴリズムでは、情報理論の情報ゲイン評価と選択機能に従って、毎回、情報ゲインが最大の機能が判断モジュールとして選択されます。ID3アルゴリズムを使用して、名目上のデータセットを分割できます。プルーニングプロセスはありません。過剰なデータマッチングの問題を取り除くために、大量の情報ゲインを生成できない隣接するリーフノードをプルーニングによってマージできます(たとえば、情報ゲインしきい値の設定)。情報ゲインを使用することの欠点は、値の数が多い属性に偏っていることです。つまり、トレーニングセットでは、属性が取る値が異なるほど、属性として使用される可能性が高くなります。属性を分割することは、意味がない場合があります。さらに、ID3は継続的に分散されるデータ機能を処理できないため、C4.5アルゴリズムがあります。CARTアルゴリズムは、継続的に分散されるデータ機能もサポートします。

C4.5はID3の改良されたアルゴリズムであり、ID3アルゴリズムの利点を継承しています。C4.5アルゴリズムは、情報ゲイン率を使用して属性を選択します。これにより、情報ゲインを使用して属性を選択するときに、多くの値を持つ属性を選択する際の欠点が克服されます。ツリー構築の過程でプルーニングが実行されます。連続属性;不完全なデータが処理されます。C4.5アルゴリズムによって生成された分類ルールは理解しやすく、精度が高いですが、ツリー構築の過程でデータセットを複数回スキャンして並べ替える必要があるため、効率は低くなります。また、複数のデータセットスキャンが必要なため、C4.5はメモリに常駐できるデータセットにのみ適しています。

CARTアルゴリズムのフルネームは、Classification And Regression Treeであり、分割標準としてGiniインデックス(最小のGiniインデックスを持つフィーチャを選択)を使用し、剪定後の操作も含まれます。ID3アルゴリズムとC4.5アルゴリズムは、トレーニングサンプルセットの学習で可能な限り多くの情報をマイニングできますが、それらによって生成される決定木は、より大きなブランチとより大きなスケールを持っています。デシジョンツリーのスケールを単純化し、デシジョンツリーの生成効率を向上させるために、GINI係数に従ってテスト属性を選択するデシジョンツリーアルゴリズムCARTが表示されます。

この記事では、構築された決定木sklearn.DecisionTreeClassifier()を呼び出すID3アルゴリズムを使用しています。調整する必要があるのは、一部のパラメーターのみです。

2.4決定木とパラメータの解釈

classifier =tree.DecisionTreeClassifier(criterion='entropy',                                                                                         splitter='random',
                                        max_depth=None,
                                        min_samples_split=3,
                                        min_samples_leaf=2,
                                        min_weight_fraction_leaf=0.0,
                                        max_features=None,
                                        random_state=None,
                                        max_leaf_nodes=None,
                                        min_impurity_decrease=0.0,
                                        min_impurity_split=None,class_weight=None,)

図2パラメータの詳細な説明

標準の黄色は、一般的に使用される調整パラメーターです。

リンク:sklearn.tree.DecisionTreeClassifier-scikit-learn中国語コミュニティ

ここではID3によって構築された決定木を使用するため、基準で「エントロピー」が選択されます。

3.決定木に基づく手書き数字の分類

3.1実験コード

##基于决策树的手写数字分类
from sklearn import tree
import numpy as np
from sklearn.datasets import load_digits


dataset = np.load('C:\\Users\\asus\\Desktop\\dateset\\mnist.npz')##获取数据集
x_train = dataset['x_train']#所有自变量,用于训练的自变量。
y_train = dataset['y_train']#这是训练数据的类别标签。
x_test = dataset['x_test']#这是剩下的数据部分,用来测试训练好的决策树的模型。
y_test = dataset['y_test']#这是测试数据的类别标签,用于区别实际类型和预测类型的准确性。

classifier = tree.DecisionTreeClassifier(criterion='entropy', splitter='random', max_depth=21, min_samples_split=3,random_state=40,)
#classifier = tree.DecisionTreeClassifier(criterion='entropy',splitter='random',max_depth=None,min_samples_split=3,min_samples_leaf=2,min_weight_fraction_leaf=0.0,max_features=None,random_state=None,max_leaf_nodes=None,min_impurity_decrease=0.0,min_impurity_split=None,class_weight=None,)

x_train = x_train.reshape(60000, 784)#第一个数字是用来索引图片,第二个数字是索引每张图片的像素点
x_test = x_test.reshape(10000, 784)

classifier.fit(x_train, y_train)
score = classifier.score(x_test, y_test)
print(score)
#tree.plot_tree(classifier)
#import matplotlib.pyplot as plt
#plt.show()
#后三行是画出决策树

3.2テストプロセス

以下は、私の個人的なテストで最良の結果が得られたパラメーターです。

criterion='entropy', 
splitter='random', 
max_depth=21, 
min_samples_split=3,
random_state=40

ここに画像の説明を挿入

図3max_depthのデバッグ結果

図3は、max_depthのパラメーターデバッグ曲線です。テストの正しいレートは86%から88%の間であることがわかります。max_depthが21より大きい場合、
正しいレートは88.6%で安定し、変動は発生しません。
[外部リンクの画像転送に失敗しました。ソースサイトにリーチ防止メカニズムがある可能性があります。画像を保存して直接アップロードすることをお勧めします(img-tpO2Kacz-1647868130461)(C:\ Users \ asus \ AppData \ Roaming \ Typora \ typora -user-images \ image-20220321105438659.png)]

図4ディシジョンツリーのテスト結果

4.注意を払う

4.1MNISTデータセットの概要

MNISTデータセットは、米国国立標準技術研究所からのものです。

米国国立標準技術研究所(NIST)

60,000のトレーニング画像と10,000のテスト画像が含まれています

各画像のサイズは28*28=784です

ここに画像の説明を挿入

図5データセット画像

次のパッケージをpython3.6環境にインストールする必要があります

pip install sklearn

pip install numpy

pip install matplotlib#这个包是画图用的

4.2MNISTデータセットの取得

4.2.1関数呼び出しを介して

from sklearn.datasets import load_digits
mnist = load_digits()

4.2.2ローカルフォルダにダウンロードしてから、

ここからダウンロードしてローカルに呼び出しました

ダウンロードアドレス1:mnist.npzデータセットは無料です-プログラマーが求めています

ダウンロードアドレス2:mnistデータセットのダウンロード-mnistデータセットはBaiduネットワークディスクのダウンロードアドレス_bigcindyのブログを提供します-CSDNblog_mnistデータセット

dataset = np.load('你的数据集位置')##获取数据集
#这里建议适用绝对路径位置
数据集下载——mnist数据集提供百度网盘下载地址_bigcindy的博客-CSDN博客_mnist数据集](https://blog.csdn.net/Jwenxue/article/details/89847251)

```python
dataset = np.load('你的数据集位置')##获取数据集
#这里建议适用绝对路径位置

おすすめ

転載: blog.csdn.net/m0_37758063/article/details/123646270