Pythonは手書き数字認識を実現します

手書き数字の認識は、手書きの写真を認識して数字を判断する、古典的な機械学習の問題です。

番号のカテゴリが0〜9であるため、非常にタイプの問題です。

この記事では、手書き数字の認識を実現するための例としてKNNアルゴリズムを取り上げます。

低次元の手書き数字認識

sklearnには、datasets.load_digits()によって呼び出される独自の手書き数字データセットがあります。

load_digitsの概要

load_digitsによって返される数字データセットには1797個のデータがあり、データの次元は64です。

数字はBunch型の辞書のようなオブジェクトであり、インデックスを使用して呼び出すことができます

転送

from sklearn import datasets
digits = datasets.load_digits()

インデックス

桁は5つの部分で構成されます:
ここに画像の説明を挿入
データ:データ、各要素は64次元のベクター
ここに画像の説明を挿入
画像:画像、各要素は8×8の行列
ここに画像の説明を挿入
ターゲット:各データに対応するラベル
ここに画像の説明を挿入
target_names:すべてのカテゴリラベル
ここに画像の説明を挿入

例として0番目の要素を取り上げます。

64次元のベクトル
ここに画像の説明を挿入
8×8行列で、数値「0」の輪郭を大まかに見ることができ
ここに画像の説明を挿入
ます。plt.imshow()を使用して画像を視覚化します。
ここに画像の説明を挿入
データカテゴリラベル
ここに画像の説明を挿入
。モデルの予測結果:
ここに画像の説明を挿入

トレーニングと予測にKNNを使用する

データセットを分割し、トレーニングセットを使用してknnをトレーニングしてから、テストセットを使用してパフォーマンスをテストします。

from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X_train, X_test, y_train, y_test = train_test_split(digits['data'], digits['target'])
knn = KNeighborsClassifier()
knn.fit(X_train, y_train)
knn.score(X_test, y_test)

スコア関数は、テストセット上のモデルの正しいレートを次の形式で返します。

model.score(X_test, y_test)

正しいレートが非常に高いことがわかります。
ここに画像の説明を挿入

高次元の手書き数字認識

データセットのインポート

ローカルからmnist.npyファイルをインポートします

import numpy as np
x_train, x_test, y_train, y_test = np.load('data/mnist/mnist.npy', allow_pickle = True)

トレーニングセットの形状
ここに画像の説明を挿入
:テストセットの形状
ここに画像の説明を挿入

データセットの形状を変更する

データの形状を確認すると、各データが28×28の行列であることがわかります。

行列に対して直接演算を実行することはできないため、行列を784のベクトルに変換する必要があります。

# reshape训练集
n_samples, n1, n2 = x_train.shape
x_train = x_train.reshape(n_samples, n1*n2).astype(np.float32)
# reshape测试集
n_samples_test, n1_test, n2_test = x_test.shape
x_test = x_test.reshape(n_samples_test, n1_test*n2_test).astype(np.float32)

このとき、各データは784のベクトルになります。
ここに画像の説明を挿入
ここに画像の説明を挿入

特徴の次元削減

トレーニングセットには60,000のデータがあるため、時間がかかります。次元を削減して実行時間を短縮できます。

PCA主成分分析を使用して、次元を64次元に縮小します

# 特征降维
from sklearn.decomposition import PCA
pca = PCA(n_components = 64) 
decom_x_train = pca.fit_transform(x_train)
decom_x_test = pca.transform(x_test)

5行目のpca.transformは以前はpca.fit_transformとして記述されていたため、正解率は非常に低く、変更後は問題ありません。

テストセットの正しいレート:
ここに画像の説明を挿入

おすすめ

転載: blog.csdn.net/weixin_43772166/article/details/112066938