Python は畳み込みニューラル ネットワーク LeNet-5 と AlexNet のトレーニングと認識を実装します

リソースのダウンロード アドレス: https://download.csdn.net/download/sheziqiong/88284348
リソースのダウンロード アドレス: https://download.csdn.net/download/sheziqiong/88284348

CNN畳み込みニューラルネットワーク

実験内容と要件

  • LeNet-5 畳み込みニューラル ネットワークを実装し、MNIST 手書き数字データベースを学習および認識し、精度などを表示するプログラムを作成します。
  • 独自のニューラル ネットワークを選択し、CIFAR-10 データベース上で画像オブジェクトのトレーニングと認識を実行します。

実験装置

Python 3.7

開発プラットフォーム:Windows10 Visual Studio Code

機械学習ライブラリ: torch 1.6.0 torchvision 0.7.0

補助: GPU アクセラレーション用の CUDA 10.2

実装

3.1 LeNet-5の実装

torch の nn.Module クラスの派生を使用すると、LeNet5 の構造は次のように記述できます。 nn.Conv2d() 関数を呼び出して畳み込み層を設定し、nn.Linear() 関数を使用して畳み込み層を設定します。完全接続操作。順方向伝導のプロセスでは、F.max_pool2d 関数を使用して 2 つのプーリングが指定されます。各レイヤーの後、結果に対して F.relu() 関数が呼び出され、結果がアクティブ化されて新しい出力が形成されます。

外部リンク画像の転送に失敗しました。ソース サイトにはリーチ防止メカニズムがある可能性があります。画像を保存して直接アップロードすることをお勧めします。

畳み込みニューラル ネットワークを実装するプロセスでは、pytorch のデータ読み込みモジュールを呼び出すことが困難になります。torch.utils.data.DataLoader()を呼び出し、バッチサイズ、ランダム再編成の有無、num_workers(プロセス数)を設定 Windowsを使用しているため、マルチスレッド対応は良くありません。

トレーニング プロセス: 最適化関数オプティマイザー (Adam アルゴリズムを使用) と損失関数 (クロス エントロピー関数 CrossEntropyLoss) を使用し、損失に対して backard() 関数を呼び出してバックプロパゲーション プロセスを実行します。トレーニング前にネットワークの train() 設定に注意し、ネットワークの過剰適合を防ぐためにバッチ正規化とドロップアウトを有効にしてください。

テストプロセス: eval() モードを有効にし、入力データをネットワークに伝播し、出力の最大値を予測結果 pred として取得します。

3.2 AlexNetの実装

ネットワーク定義は次のとおりです。

データはトレーニング前に前処理する必要があり、torchvision の処理関数を使用してサイズ変更してテンソルに変換することに注意してください。さらに、Normalize 関数が呼び出されて、元のテンソルが (0,1) から (-1,1) の範囲に変換されます。

CIFAR-10 のトレーニングと検出は MNIST のトレーニングと検出に似ているため、再度説明しません。

実験結果と分析

4.1 LeNet-5 のトレーニングと MNIST の認定

BATCH_SIZE を 512 に設定し、合計 10 エポックの間トレーニングします。各エポックはトレーニング データを渡し、次にテスト データを渡して、精度と損失関数の値を取得します。トレーニングとテストの出力結果は LeNet.log に保存され、モデルは LeNet.pth として保存されます。

トレーニング結果は次のように視覚化されます。

4.2 AlexNet による CIFAR-10 のトレーニングと認定

BATCH_SIZE を 32 に設定し、合計 20 エポックの間トレーニングします。各エポックはトレーニング データを渡し、次にテスト データを渡して、精度と損失関数の値を取得します。トレーニングとテストの出力結果は AlexNet.log に保存され、モデルは AlexNet.pth として保存されます。

AlexNet ネットワークは比較的複雑で、CIFAR-10 データ量も大きいため、トレーニングされたネットワーク構造が正しいかどうかを確認するために次のように出力されます。

まず、トレーニング結果をテストするためにデータのバッチをランダムに選択します。

実際のラベルと予測ラベルを比較すると、32 枚中 27 枚の画像が正しく判定され、正解率は約 84% でした。

GroundTruth:  cat  ship  ship airplane  frog  frog  automobile  frog   cat   automobile  airplane truck   dog horse truck  ship   dog horse  ship  frog horse  airplane  deer  truck
dog   bird  deer airplane truck  frog  frog   dog 
Predicted:    cat  ship  ship airplane  frog  frog  truck     frog   cat   automobile airplane  truck   dog horse truck  ship   dog horse  ship  frog horse  bird      airplane truck  deer  frog  deer airplane truck  frog  frog   dog 

さらに、50,000 個のトレーニング データに対するテスト結果は 92% の精度を示し、10,000 個の新しいテスト データに対する結果は 77% でした。10 個のラベルの中で、ship の正解率は 91% と最も高く、cat の正解率は 60% 近くと最も低いです。

リソースのダウンロード アドレス: https://download.csdn.net/download/sheziqiong/88284348
リソースのダウンロード アドレス: https://download.csdn.net/download/sheziqiong/88284348

おすすめ

転載: blog.csdn.net/newlw/article/details/132625381