一个例子了解迁移学习

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/wangyangzhizhou/article/details/84981876

迁移学习

对于传统机器学习而言,要求训练样本与测试样本满足独立同分布,而且必须要有足够多的训练样本。而迁移学习能把一个领域(即源领域)的知识,迁移到另外一个领域(即目标领域),目标领域往往只有少量有标签样本,使得目标领域能够取得更好的学习效果。

image

迁移方式

  • 样本迁移,在源领域中找出与目标领域相似的样本,增加该样本的权重,使其在预测目标与的比重加大。
  • 特征迁移,源领域与目标领域包含共同的交叉特征,通过特征变换将源领域和目标领域的的特征变换到相同空间,使它们具有相同分布。
  • 模型迁移,源领域和目标领域共享模型参数,将源领域已训练好的网络模型应用到目标领域的新问题上。
  • 关系迁移,源领域和目标领域具有某种相似关系,可以将源领域的逻辑关系应用到目标领域中。

模型迁移

这里基于预训练的卷积神经网络训练一组新参数,然后将其用于分类任务,这样就能共享模型参数,避免了从头开始训练模型的参数,大大减少训练时间。

数据集

在示例中使用flower17数据集,它是一个包含17种花卉类别的数据集,每个类别有80张图像。收集的花都是英国一些常见的花,这些图像具有大比例、不同姿态和光线变化等性质。

使用水仙花和款冬这两类花,并且在预训练的VGG16网络之上构建分类器。

image

image

实现

首先导入所有必需的库,包括应用程序、预处理、模型检查点以及相关对象,cv2库和NumPy库用于图像处理和数值的基本操作。

from keras import applications
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.models import Model
from keras.layers import Dropout, Flatten, Dense
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.applications.vgg16 import preprocess_input
import cv2
import numpy as np

定义输入、数据源及与训练参数相关的所有变量。

img_width, img_height = 224, 224
train_data_dir = "data/train"
validation_data_dir = "data/validation"
nb_train_samples = 300
nb_validation_samples = 100
batch_size = 16
epochs = 1

调用VGG16预训练模型,其中不包括顶部的平整化层。冻结不参与训练的层,这里我们冻结前五层,然后添加自定义层,从而创建最终的模型。

model = applications.VGG16(weights="imagenet", include_top=False, input_shape=(img_width, img_height, 3))
for layer in model.layers[:5]:
    layer.trainable = False
x = model.output
x = Flatten()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(0.5)(x)
x = Dense(1024, activation="relu")(x)
predictions = Dense(2, activation="softmax")(x)
model_final = Model(inputs=model.input, output=predictions)

接着开始编译模型,并为训练、测试数据集创建图像数据增强生成器。

model_final.compile(loss="categorical_crossentropy", optimizer=optimizers.SGD(lr=0.0001, momentum=0.9),
                    metrics=["accuracy"])
train_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                   width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)
test_datagen = ImageDataGenerator(rescale=1. / 255, horizontal_flip=True, fill_mode="nearest", zoom_range=0.3,
                                  width_shift_range=0.3, height_shift_range=0.3, rotation_range=30)

生成增强后新的数据,根据情况保存模型。

train_generator = train_datagen.flow_from_directory(train_data_dir, target_size=(img_height, img_width),
                                                    batch_size=batch_size, class_mode="categorical")
validation_generator = test_datagen.flow_from_directory(validation_data_dir, target_size=(img_height, img_width),
                                                        class_mode="categorical")
checkpoint = ModelCheckpoint("vgg16_1.h5", monitor='val_acc', verbose=1, save_best_only=True, save_weights_only=False,
                             mode='auto', period=1)
early = EarlyStopping(monitor='val_acc', min_delta=0, patience=10, verbose=1, mode='auto')

开始对模型中新的网络层进行拟合。

model_final.fit_generator(train_generator, samples_per_epoch=nb_train_samples, nb_epoch=epochs,
                          validation_data=validation_generator, nb_val_samples=nb_validation_samples,
                          callbacks=[checkpoint, early])

练完成后用水仙花图像测试这个新模型,输出的正确值应该为接近[1.,0.]的数组。

im = cv2.resize(cv2.imread('data/test/gaff2.jpg'), (img_width, img_height))
im = np.expand_dims(im, axis=0).astype(np.float32)
im = preprocess_input(im)
out = model_final.predict(im)
print(out)
print(np.argmax(out))
 1/18 [>.............................] - ETA: 16:43 - loss: 0.9380 - acc: 0.3750
 2/18 [==>...........................] - ETA: 13:51 - loss: 0.8720 - acc: 0.4062
 3/18 [====>.........................] - ETA: 12:32 - loss: 0.8382 - acc: 0.4167
 4/18 [=====>........................] - ETA: 10:53 - loss: 0.8103 - acc: 0.4663
 5/18 [=======>......................] - ETA: 10:00 - loss: 0.8208 - acc: 0.4606
 6/18 [=========>....................] - ETA: 9:12 - loss: 0.8083 - acc: 0.4567 
 7/18 [==========>...................] - ETA: 8:24 - loss: 0.7891 - acc: 0.4718
 8/18 [============>.................] - ETA: 7:37 - loss: 0.7994 - acc: 0.4832
 9/18 [==============>...............] - ETA: 6:51 - loss: 0.7841 - acc: 0.4850Epoch 00001: val_acc improved from -inf to 0.40000, saving model to vgg16_1.h5

 9/18 [==============>...............] - ETA: 7:16 - loss: 0.7841 - acc: 0.4850 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00[[0.2213877  0.77861226]]

github

https://github.com/sea-boat/DeepLearning-Lab/blob/master/transfer_learning.py

-------------推荐阅读------------

我的开源项目汇总(机器&深度学习、NLP、网络IO、AIML、mysql协议、chatbot)

为什么写《Tomcat内核设计剖析》

我的2017文章汇总——机器学习篇

我的2017文章汇总——Java及中间件

我的2017文章汇总——深度学习篇

我的2017文章汇总——JDK源码篇

我的2017文章汇总——自然语言处理篇

我的2017文章汇总——Java并发篇


跟我交流,向我提问:

欢迎关注:

猜你喜欢

转载自blog.csdn.net/wangyangzhizhou/article/details/84981876
今日推荐