Keras Fine Tuning (1)

Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/github_39611196/article/details/86422080

Contents

Keras Fine Tuning (1)

Keras Fine Tuning (2)

Keras Fine Tuning (3)

Dataset download: https://download.csdn.net/download/github_39611196/10940372


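This article introduces fine tuning in Keras, using a classification task on a dataset of watermelon, pumpkin, and tomato images as a worked example.

[Figure: sample images from the dataset]

Import the required modules and set the dataset paths: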

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import keras
from keras.preprocessing.image import ImageDataGenerator, load_img

train_dir = './clean-dataset/train'
validation_dir = './clean-dataset/validation/'
image_size = 224

1. Freeze all layers:

Train the model:

from keras.applications import VGG16

# Load the VGG16 convolutional base with ImageNet weights, without the top classifier
vgg_conv = VGG16(weights='imagenet', include_top=False, input_shape=(image_size, image_size, 3))

# Freeze all layers
for layer in vgg_conv.layers:
    layer.trainable = False

# Check the trainable attribute of every layer
for layer in vgg_conv.layers:
    print(layer, layer.trainable)
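
Freezing can be pictured as leaving a layer's weights out of the optimizer update. The following is a conceptual NumPy sketch of that idea, not Keras internals; the parameter names `conv_base` and `classifier`, the shapes, and the `sgd_step` helper are all hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical parameters: a "pretrained" base and a newly added classifier head.
params = {
    "conv_base": rng.normal(size=(4, 3)),
    "classifier": rng.normal(size=(3, 2)),
}
# Mirrors layer.trainable: the base is frozen, the head is not.
trainable = {"conv_base": False, "classifier": True}


def sgd_step(params, grads, trainable, lr=0.1):
    """Apply one gradient step, skipping every parameter marked as frozen."""
    return {
        name: (weights - lr * grads[name]) if trainable[name] else weights
        for name, weights in params.items()
    }


# Dummy gradients, only to show which parameters move.
grads = {name: np.ones_like(weights) for name, weights in params.items()}
updated = sgd_step(params, grads, trainable)

base_unchanged = np.array_equal(updated["conv_base"], params["conv_base"])
head_changed = not np.array_equal(updated["classifier"], params["classifier"])
print(base_unchanged, head_changed)
```

With every VGG16 layer frozen, only the layers added on top of the base receive updates during training.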

from keras import models
from keras import layers
from keras import optimizers

# Create the model
model = models.Sequential()

# Add the VGG convolutional base
model.add(vgg_conv)

# Add the new classifier layers
model.add(layers.Flatten())
model.add(layers.Dense(1024, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(3, activation='softmax'))  # softmax output over the 3 classes

# Print a summary of the network
model.summary()
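
As a rough check on what `model.summary()` should report for the newly added layers, their parameter counts can be worked out by hand. This sketch assumes the standard VGG16 convolutional base, whose five 2x2 pooling stages reduce a 224x224 input to a 7x7x512 feature map; the variable names are illustrative only.

```python
# Output shape of the VGG16 convolutional base for a 224x224x3 input (assumed: 224 / 2**5 = 7).
feature_map = (7, 7, 512)

# Flatten has no parameters; it only reshapes the feature map into a vector.
flat_features = feature_map[0] * feature_map[1] * feature_map[2]

# A Dense layer has (inputs * units) weights plus one bias per unit.
dense1_params = flat_features * 1024 + 1024
dense2_params = 1024 * 3 + 3

print(flat_features)                  # 25088
print(dense1_params)                  # 25691136
print(dense2_params)                  # 3075
print(dense1_params + dense2_params)  # 25694211
```

Because the base is frozen, these are the only parameters that training adjusts.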

# No data augmentation; only rescale pixel values to [0, 1]
train_datagen = ImageDataGenerator(rescale=1./255)
validation_datagen = ImageDataGenerator(rescale=1./255)
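
The `rescale=1./255` argument multiplies every pixel value by that factor, mapping 8-bit intensities into the range [0, 1]. A minimal NumPy sketch of the effect, using a made-up row of pixel values:

```python
import numpy as np

# Hypothetical 8-bit pixel intensities.
pixels = np.array([0, 128, 255], dtype=np.uint8)

# Same multiplication the rescale factor applies to each image.
scaled = pixels * (1.0 / 255)

print(scaled.dtype, scaled.min(), scaled.max())
```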

# Batch sizes
train_batchsize = 100
val_batchsize = 10

# Data generator for the training data.
# class_mode='categorical' yields one-hot labels for multi-class classification.
train_generator = train_datagen.flow_from_directory(
    train_dir, target_size=(image_size, image_size),
    batch_size=train_batchsize, class_mode='categorical')

# Data generator for the validation data (no shuffling, so the order matches the filenames)
validation_generator = validation_datagen.flow_from_directory(
    validation_dir, target_size=(image_size, image_size),
    batch_size=val_batchsize, class_mode='categorical', shuffle=False)

# Compile the model
model.compile(loss='categorical_crossentropy', optimizer=optimizers.RMSprop(lr=1e-4), metrics=['acc'])
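
To make the pairing of a softmax output layer with the `categorical_crossentropy` loss more concrete, here is a small NumPy sketch of both functions. It is a simplified illustration rather than the Keras implementation; the logits, the one-hot label, and the `eps` clipping value are made up for the example.

```python
import numpy as np


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 along the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)  # for numerical stability
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy between one-hot labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0)  # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)


# One hypothetical sample with three class scores; the true class is index 0.
logits = np.array([[2.0, 1.0, 0.1]])
y_true = np.array([[1.0, 0.0, 0.0]])

probs = softmax(logits)
loss = categorical_crossentropy(y_true, probs)
print(probs, loss)
```

With a one-hot label, the loss reduces to the negative log of the probability assigned to the true class.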

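# Train the model. The step counts are rounded up so that they are whole
# numbers and every sample is seen once per epoch.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=int(np.ceil(train_generator.samples / train_generator.batch_size)),
    epochs=20,
    validation_data=validation_generator,
    validation_steps=int(np.ceil(validation_generator.samples / validation_generator.batch_size)),
    verbose=1)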
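
The ratio of samples to batch size is fractional whenever the dataset size is not a multiple of the batch size, so it helps to see how the number of steps per epoch comes out once rounded up. The sample counts below are hypothetical, and `steps_for` is an illustrative helper rather than a Keras function.

```python
import math


def steps_for(num_samples, batch_size):
    """Number of batches needed to cover every sample once (last batch may be smaller)."""
    return math.ceil(num_samples / batch_size)


print(steps_for(600, 100))  # 6
print(steps_for(150, 100))  # 2
print(steps_for(150, 10))   # 15
```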

# Save the model
model.save('all_freezed.h5')

# Plot the accuracy and loss curves
acc = history.history['acc']
val_acc = history.history['val_acc']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

plt.plot(epochs, acc, 'b', label='Training acc')
plt.plot(epochs, val_acc, 'r', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()

plt.figure()

plt.plot(epochs, loss, 'b', label='Training loss')
plt.plot(epochs, val_loss, 'r', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()

plt.show()

[Figure: training and validation accuracy and loss curves]

Display the error rate:

# Create the data generator for the prediction data
validation_generator = validation_datagen.flow_from_directory(validation_dir, target_size=(image_size, image_size), batch_size=val_batchsize, class_mode='categorical', shuffle=False)

# Get the file names from the generator
fnames = validation_generator.filenames

# Get the ground-truth class indices from the generator
ground_truth = validation_generator.classes

# Get the mapping from label to class index from the generator
label2index = validation_generator.class_indices

# Build the reverse mapping from class index to label
idx2label = dict((v, k) for k, v in label2index.items())
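
The inversion of `class_indices` is easier to follow with concrete values. The dictionary below is hypothetical; the real keys are the names of the class subdirectories under the validation directory, which are not listed in this article.

```python
# Hypothetical class_indices mapping (label -> class index).
label2index = {'pumpkin': 0, 'tomato': 1, 'watermelon': 2}

# Swap keys and values to look up a label from a predicted class index.
idx2label = {index: label for label, index in label2index.items()}

print(idx2label[2])
```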

# Predict on the validation set (step count rounded up to a whole number)
predictions = model.predict_generator(validation_generator, steps=int(np.ceil(validation_generator.samples / validation_generator.batch_size)), verbose=1)
predicted_classes = np.argmax(predictions, axis=1)

# Indices of the samples whose predicted class differs from the ground truth
errors = np.where(predicted_classes != ground_truth)[0]
print("No of errors = {}/{}".format(len(errors), validation_generator.samples))
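
The error count above comes from comparing the arg-max of each prediction row with the ground-truth index. A toy NumPy example with made-up probabilities and labels shows the mechanics:

```python
import numpy as np

# Hypothetical softmax outputs for four images over three classes.
predictions = np.array([
    [0.70, 0.20, 0.10],
    [0.10, 0.30, 0.60],
    [0.20, 0.50, 0.30],
    [0.05, 0.05, 0.90],
])
# Hypothetical ground-truth class indices for the same four images.
ground_truth = np.array([0, 1, 1, 2])

# The most probable class per row.
predicted_classes = np.argmax(predictions, axis=1)

# Positions where the prediction disagrees with the ground truth.
errors = np.where(predicted_classes != ground_truth)[0]
print("No of errors = {}/{}".format(len(errors), len(ground_truth)))
```

Here only the second image is misclassified, so the printed count is 1 out of 4.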

# Display the misclassified images
for i in range(len(errors)):
    pred_class = np.argmax(predictions[errors[i]])
    pred_label = idx2label[pred_class]

    # The true label is the name of the subdirectory the file sits in
    title = 'Original label:{}, Prediction :{}, confidence : {:.3f}'.format(
        fnames[errors[i]].split('/')[0],
        pred_label,
        predictions[errors[i]][pred_class])

    original = load_img('{}/{}'.format(validation_dir, fnames[errors[i]]))
    plt.figure(figsize=[7, 7])
    plt.axis('off')
    plt.title(title)
    plt.imshow(original)
    plt.show()

Prediction results:

[Output: number of errors]

[Figure: misclassified images]
