CNN アルゴリズムに基づく動物認識プロジェクト 2: VGG + 全結合層マージ モデル

リソース

15 の一般的な動物認識データセット


1. データセットの概要

(1) データセットは 2 つの部分に分かれています: トレーニング セット train とテスト セット test
(2) 動物カテゴリ: 鳥、猫、牛、鶏、犬、イルカ、アヒル、ゾウ、キリン、サル、ブタ、ウサギ、ラット、羊、虎。
(3) 列車データ セットには各種類の動物の写真が 200 枚あります。
(4) テスト データ セットには各種類の動物の写真が 20 枚あります。

2. 開発手順

1. ライブラリをインポートする

from keras.applications.vgg16 import VGG16
from keras.models import Sequential
from keras.layers import Dropout,Flatten,Dense
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
from keras_preprocessing.image import img_to_array,load_img
from keras.models import load_model
import numpy as np

2. モデルを定義する

vgg16_model = VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))

#搭建全连接层
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
top_model.add(Dense(256,activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(15,activation='softmax'))

#两个模型进行合并
model = Sequential()
model.add(vgg16_model)
model.add(top_model)
model.summary()

ここに画像の説明を挿入します

3. オプティマイザーを定義する

model.compile(optimizer=SGD(lr=1e-4,momentum=0.9),loss='categorical_crossentropy',metrics=['accuracy'])

4. 学習データの強化

train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2, #随机旋转度数
    height_shift_range=0.2, #随机水平位移
    rescale=1/255, #数据归一化
    shear_range=0.2, #随机裁剪
    zoom_range=0.2, #随机放大
    horizontal_flip=True, #水平翻转
    fill_mode='nearest', #填充方式
)

5. テストデータの正規化

test_data = ImageDataGenerator(
    rescale=1/255, #数据归一化
)

6. データ生成

# 定义数据生成
batch_size = 32  #每次传32张照片

#生成训练数据
train_generator = train_datagen.flow_from_directory(
    '/BASICCNN/image/train',
    target_size=(150,150),
    batch_size=batch_size,
)

#生成测试数据
test_generator = test_data.flow_from_directory(
    '/BASICCNN/image/test',
    target_size=(150,150),
    batch_size=batch_size,
)

7. カテゴリの定義を表示する

print(train_generator.class_indices)

ここに画像の説明を挿入します
ここに画像の説明を挿入します

8.モデルをトレーニングする

history=model.fit_generator(train_generator,epochs=10,validation_data=test_generator)
model.save('/BASICCNN/TrainModel_h5/model_VGG16Train.h5')

ここに画像の説明を挿入します
ここに画像の説明を挿入します

9. トレーニングと検証の結果をプロットする

plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model_Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.legend(['Train_Accuracy','Valid_Accuracy'],loc='upper left')
plt.savefig('/BASICCNN/TrainImage/VGG16Train_accuracy.png')
plt.show()

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model_Loss')
plt.ylabel('Loss')
plt.xlabel('Epochs')
plt.legend(['Train_Loss','Valid_Loss'],loc='upper left')
plt.savefig('/BASICCNN/TrainImage/VGG16Train_loss.png')
plt.show()

ここに画像の説明を挿入します
ここに画像の説明を挿入します
テストおよび視覚化パートのリファレンス: CNN アルゴリズム カスタム モデルに基づく動物認識プロジェクト 1

おすすめ

転載: blog.csdn.net/weixin_43312470/article/details/124329750