Animal recognition project 2 based on CNN: merging a pretrained VGG16 base with a fully connected classifier
Resources
Recognition dataset covering 15 common animals
1. Introduction to the dataset
(1) The dataset is split into two parts: a training set (train) and a test set (test).
(2) Animal categories: bird, cat, cattle, chicken, dog, dolphin, duck, elephant, giraffe, monkey, pig, rabbit, rat, sheep, tiger.
(3) The train set contains 200 images of each animal.
(4) The test set contains 20 images of each animal.
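Because step 6 reads the images with flow_from_directory, each of train and test must contain one subfolder per animal, and the subfolder name becomes the label. The sketch below, a hypothetical helper count_images_per_class that uses only the Python standard library, counts the image files in each class folder so the 200/20 split can be checked before training. The set of accepted file extensions is an assumption; adjust it to your data.

```python
import tempfile
from pathlib import Path

# File extensions treated as images (an assumption; adjust to your data).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}

def count_images_per_class(root):
    """Return {class_name: image_count} for a directory laid out as root/<class>/<image>."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
    return counts

# Demo on a tiny throwaway tree with the expected root/<class>/<image> layout.
with tempfile.TemporaryDirectory() as tmp:
    for name, n in [("bird", 3), ("cat", 2)]:
        d = Path(tmp) / name
        d.mkdir()
        for i in range(n):
            (d / f"{i}.jpg").touch()
    print(count_images_per_class(tmp))  # {'bird': 3, 'cat': 2}
```

Run against the real folders (for example count_images_per_class('/BASICCNN/image/train')), it should report 15 classes with 200 images each, and the test folder 20 each: 3000 and 300 images in total.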
2. Development steps
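1. Import the libraries
# Assumes the Keras 2 API (standalone Keras 2.x or tf.keras up to TensorFlow 2.15); ImageDataGenerator is not available in Keras 3
from keras.applications.vgg16 import VGG16
from keras.models import Sequential
from keras.layers import Dropout,Flatten,Dense
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
# img_to_array, load_img and load_model are used for testing/prediction with the saved model
from keras.preprocessing.image import img_to_array,load_img
from keras.models import load_model
import numpy as np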
2. Define the model
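# Pretrained VGG16 convolutional base: ImageNet weights, original classifier removed, 150x150 RGB input
vgg16_model = VGG16(weights='imagenet',include_top=False,input_shape=(150,150,3))
# Build the fully connected top model (the classifier)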
top_model = Sequential()
top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
top_model.add(Dense(256,activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(15,activation='softmax'))
# Merge the two models: VGG16 base followed by the top model
model = Sequential()
model.add(vgg16_model)
model.add(top_model)
model.summary()
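As a sanity check on what model.summary() reports, the arithmetic below recomputes the shapes and parameter counts. It assumes the standard VGG16 layout: 13 3×3 convolutions with 'same' padding in five blocks, each block ending in a 2×2 max-pool with stride 2 and 'valid' padding. With a 150×150 input the base therefore outputs a 4×4×512 map, so Flatten produces 8192 features.

```python
# Standard VGG16 convolutional layout: (in_channels, out_channels) per 3x3 conv.
VGG16_CONVS = [
    (3, 64), (64, 64),                      # block 1
    (64, 128), (128, 128),                  # block 2
    (128, 256), (256, 256), (256, 256),     # block 3
    (256, 512), (512, 512), (512, 512),     # block 4
    (512, 512), (512, 512), (512, 512),     # block 5
]

def conv_params(c_in, c_out, k=3):
    """Weights (k*k*c_in*c_out) plus one bias per output channel."""
    return k * k * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out

def vgg16_output_side(side, n_pools=5):
    """'same' 3x3 convs keep the size; each 2x2/stride-2 'valid' pool floors it by half."""
    for _ in range(n_pools):
        side //= 2
    return side

base_params = sum(conv_params(i, o) for i, o in VGG16_CONVS)
side = vgg16_output_side(150)
flat = side * side * 512
top_params = dense_params(flat, 256) + dense_params(256, 15)  # Dropout has no weights

print(side, flat)                 # 4 8192
print(base_params)                # 14714688
print(top_params)                 # 2101263
print(base_params + top_params)   # 16815951
```

These totals should match the summary. Since no layers are frozen in step 2, all of them are trainable, which is why the compile step uses a small learning rate.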
3. Compile the model (optimizer, loss and metric)
# learning_rate replaces the deprecated lr argument
model.compile(optimizer=SGD(learning_rate=1e-4,momentum=0.9),loss='categorical_crossentropy',metrics=['accuracy'])
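To show what these compile settings mean, here is a plain-Python sketch of the update Keras documents for SGD with momentum (non-Nesterov form): velocity = momentum * velocity - learning_rate * grad, then w = w + velocity. It also includes a simplified categorical cross-entropy for a one-hot label. This is an illustration on a toy quadratic, not Keras's own implementation; the helper names are hypothetical.

```python
import math

def sgd_momentum_step(w, velocity, grad, learning_rate=1e-4, momentum=0.9):
    """One SGD-with-momentum update (non-Nesterov form)."""
    velocity = momentum * velocity - learning_rate * grad
    return w + velocity, velocity

def categorical_crossentropy(one_hot, probs, eps=1e-7):
    """Simplified -sum(y * log(p)); probabilities are clipped to avoid log(0)."""
    return -sum(y * math.log(min(max(p, eps), 1 - eps)) for y, p in zip(one_hot, probs))

# Toy problem: minimise f(w) = w**2, whose gradient is 2w. A learning rate much
# larger than 1e-4 is used only so this demo converges in a few hundred steps.
w, v = 5.0, 0.0
for _ in range(500):
    w, v = sgd_momentum_step(w, v, grad=2 * w, learning_rate=0.01)
print(abs(w) < 1e-6)  # True

# Loss for a 3-class example whose true class gets probability 0.7.
print(round(categorical_crossentropy([0, 1, 0], [0.2, 0.7, 0.1]), 4))  # 0.3567
```

Only the probability assigned to the true class contributes to the loss, so the 15-way softmax is pushed to put its mass on the correct animal.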
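4. Training data augmentation
train_datagen = ImageDataGenerator(
    rotation_range=40, #random rotation by up to 40 degrees
    width_shift_range=0.2, #random horizontal shift by up to 20% of the width
    height_shift_range=0.2, #random vertical shift by up to 20% of the height
    rescale=1/255, #normalize pixel values to [0, 1]
    shear_range=0.2, #random shear transformation (shear intensity, angle in degrees)
    zoom_range=0.2, #random zoom in the range [0.8, 1.2]
    horizontal_flip=True, #random horizontal flip
    fill_mode='nearest', #fill pixels created by the transforms with the nearest edge value
)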
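The generator applies these transforms on the fly with freshly drawn random parameters for every image. Below is a simplified NumPy sketch of three of them: rescale, horizontal flip, and a horizontal shift with 'nearest' filling. The helper names are hypothetical, and Keras itself combines shifts, rotation, shear and zoom into a single affine transform, so treat this only as an illustration of the behavior.

```python
import numpy as np

rng = np.random.default_rng(0)

def rescale(img):
    """Map uint8 pixels in [0, 255] to float32 values in [0, 1] (rescale=1/255)."""
    return img.astype(np.float32) / 255.0

def horizontal_flip(img):
    """Mirror an (H, W, C) image left to right."""
    return img[:, ::-1, :]

def shift_horizontal_nearest(img, dx):
    """Shift content right by dx pixels (left if negative); vacated columns
    repeat the nearest edge column, like fill_mode='nearest'."""
    w = img.shape[1]
    src_cols = np.clip(np.arange(w) - dx, 0, w - 1)
    return img[:, src_cols, :]

def random_augment(img, width_shift_range=0.2):
    """Flip with probability 0.5, shift by up to width_shift_range * width, then rescale."""
    if rng.random() < 0.5:
        img = horizontal_flip(img)
    dx = int(round(rng.uniform(-width_shift_range, width_shift_range) * img.shape[1]))
    return rescale(shift_horizontal_nearest(img, dx))

row = np.array([[[10], [20], [30]]], dtype=np.uint8)  # shape (1, 3, 1)
print(horizontal_flip(row)[0, :, 0])                  # [30 20 10]
print(shift_horizontal_nearest(row, 1)[0, :, 0])      # [10 10 20]
```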
5. Test data normalization
test_data = ImageDataGenerator(
    rescale=1/255, #normalize pixel values to [0, 1]; no augmentation for test data
)
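Test images receive only the same 1/255 rescaling, with no augmentation, so the network sees inputs on the scale it was trained on. The same rule applies when predicting a single image later. A minimal sketch, assuming the image has already been loaded at 150×150 (for example with load_img(path, target_size=(150, 150)) followed by img_to_array, both imported in step 1); preprocess_for_prediction is a hypothetical helper:

```python
import numpy as np

def preprocess_for_prediction(img_array):
    """Scale a single (150, 150, 3) image like the test generator does and add
    a leading batch axis, giving shape (1, 150, 150, 3) for model.predict."""
    x = np.asarray(img_array, dtype=np.float32) / 255.0
    return np.expand_dims(x, axis=0)

dummy = np.full((150, 150, 3), 255, dtype=np.uint8)  # stand-in for a real photo
batch = preprocess_for_prediction(dummy)
print(batch.shape, batch.max())  # (1, 150, 150, 3) 1.0
```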
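6. Data generation
# Define the data generators
batch_size = 32 # 32 images per batch
# Generate training data; labels come from the class subfolder names
# class_mode defaults to 'categorical' (one-hot labels), matching categorical_crossentropy
train_generator = train_datagen.flow_from_directory(
    '/BASICCNN/image/train',
    target_size=(150,150), # resize every image to the model's input size
    batch_size=batch_size,
)
# Generate test data
test_generator = test_data.flow_from_directory(
    '/BASICCNN/image/test',
    target_size=(150,150),
    batch_size=batch_size,
)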
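With batch_size = 32, one pass over the data (one epoch) takes ceil(n / 32) batches, and the last batch is smaller whenever n is not a multiple of 32. Using the dataset sizes from section 1 (15 × 200 = 3000 training and 15 × 20 = 300 test images), the arithmetic works out as follows; the helper names are hypothetical.

```python
def batches_per_epoch(n_images, batch_size=32):
    """Number of batches needed to see every image once (ceiling division)."""
    return (n_images + batch_size - 1) // batch_size

def last_batch_size(n_images, batch_size=32):
    """Size of the final, possibly partial, batch."""
    remainder = n_images % batch_size
    return remainder if remainder else batch_size

n_classes = 15
n_train, n_test = n_classes * 200, n_classes * 20
print(n_train, batches_per_epoch(n_train), last_batch_size(n_train))  # 3000 94 24
print(n_test, batches_per_epoch(n_test), last_batch_size(n_test))     # 300 10 12
```

These should be the values reported by len(train_generator) and len(test_generator), i.e. the number of steps shown per epoch during training and validation.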
7. View the class-to-index mapping
print(train_generator.class_indices)
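flow_from_directory numbers the classes in alphanumeric order of the subfolder names. Assuming the folders are named exactly as the categories listed in section 1, the mapping can be reproduced and inverted to turn the model's 15-way softmax output into a class name. decode_prediction is a hypothetical helper for illustration.

```python
CLASSES = ["bird", "cat", "cattle", "chicken", "dog", "dolphin", "duck", "elephant",
           "giraffe", "monkey", "pig", "rabbit", "rat", "sheep", "tiger"]

# Same rule flow_from_directory applies: sort folder names, number them from 0.
class_indices = {name: i for i, name in enumerate(sorted(CLASSES))}
index_to_class = {i: name for name, i in class_indices.items()}

def decode_prediction(probs):
    """Return (class_name, probability) for the highest-scoring softmax entry."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return index_to_class[best], probs[best]

print(class_indices["cattle"], class_indices["tiger"])  # 2 14
```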
8. Train the model
# fit_generator is deprecated in TensorFlow 2.x; fit accepts generators directly
history=model.fit(train_generator,epochs=10,validation_data=test_generator)
model.save('/BASICCNN/TrainModel_h5/model_VGG16Train.h5')
9. Plot training and validation results
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model_Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epochs')
plt.legend(['Train_Accuracy','Valid_Accuracy'],loc='upper left')
plt.savefig('/BASICCNN/TrainImage/VGG16Train_accuracy.png')
plt.show()
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model_Loss')
plt.ylabel('Loss')
plt.xlabel('Epochs')
plt.legend(['Train_Loss','Valid_Loss'],loc='upper left')
plt.savefig('/BASICCNN/TrainImage/VGG16Train_loss.png')
plt.show()
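The two plotting blocks above differ only in the metric name. A small helper (hypothetical name plot_history_metric) can remove the duplication; it reads the same history.history dictionary, in which Keras stores each training metric under its name and the matching validation metric under a 'val_' prefix. The Agg backend line is an assumption for running without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to keep plt.show() windows
import matplotlib.pyplot as plt

def plot_history_metric(history_dict, metric, out_path):
    """Plot training vs. validation curves for one metric and save the figure."""
    fig, ax = plt.subplots()
    ax.plot(history_dict[metric], label=f"Train_{metric}")
    ax.plot(history_dict["val_" + metric], label=f"Valid_{metric}")
    ax.set_title(f"Model_{metric}")
    ax.set_xlabel("Epochs")
    ax.set_ylabel(metric)
    ax.legend(loc="upper left")
    fig.savefig(out_path)
    plt.close(fig)  # free the figure so repeated calls do not accumulate memory
    return out_path
```

Usage would then be, for example, plot_history_metric(history.history, 'accuracy', '/BASICCNN/TrainImage/VGG16Train_accuracy.png') and the same call with 'loss'.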
For the testing and visualization part, refer to Animal recognition project 1 based on CNN: custom model.