Deep Learning: Transfer Learning

Code:

from keras.applications import ResNet50
from keras.models import Sequential
from keras.layers import Dense, Flatten, GlobalAveragePooling2D

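num_classes = 2  # number of output classes
# Local copy of the ImageNet weights without the top layers; not used below, where
# weights='imagenet' makes Keras download (and cache) the same file
resnet_weights_path = 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'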

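my_new_model = Sequential()
# Pretrained ResNet50 without its ImageNet classifier; pooling='avg' turns the last
# feature maps into one 2048-dimensional vector per image
my_new_model.add(ResNet50(include_top=False, pooling='avg', weights='imagenet'))
# New classification head: one softmax output per class
my_new_model.add(Dense(num_classes, activation='softmax'))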

# Freeze the first layer (the pretrained ResNet50 base); only the new Dense head is trained
my_new_model.layers[0].trainable = False

# Compile my_new_model: plain SGD, categorical cross-entropy to match the one-hot
# labels produced by class_mode='categorical', and accuracy as the reported metric
my_new_model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
from keras.applications.resnet50 import preprocess_input
from keras.preprocessing.image import ImageDataGenerator

image_size = 224
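# data_generator = ImageDataGenerator(preprocessing_function=preprocess_input)  # raised an error in the author's environment
data_generator = ImageDataGenerator()  # images are therefore fed as raw RGB values in [0, 255]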

train_generator = data_generator.flow_from_directory(
        directory='..\\Using Transfer Learning\\images\\train',
        target_size=(image_size, image_size),
        shuffle=True,
        batch_size=22,
        class_mode='categorical')

print('classes of train_generator:',train_generator.class_indices)

validation_generator = data_generator.flow_from_directory(
        directory='..\\Using Transfer Learning\\images\\val',
        target_size=(image_size, image_size),
        class_mode='categorical')
print('classes of validation_generator:', validation_generator.class_indices)
my_new_model.fit_generator(
        train_generator,
        epochs=1,
        steps_per_epoch=4,
        validation_data=validation_generator,
        validation_steps=6)
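To make the model-side lines concrete (a frozen pretrained base, a new softmax head, and the `sgd` / `categorical_crossentropy` compile settings), here is a minimal pure-Python sketch of the same idea at toy scale. It is not Keras code: `BASE_WEIGHTS`, `frozen_base`, `DenseSoftmax` and `train` are hypothetical names, a fixed 2×2 matrix stands in for the pretrained ResNet50 base, and the data are made up. Only the head's weights receive SGD updates, which is what `my_new_model.layers[0].trainable = False` asks Keras to do.

```python
import math
import random

# Hypothetical stand-in for the pretrained ResNet50 base. It is stored as a tuple
# and no update below touches it, mirroring my_new_model.layers[0].trainable = False.
BASE_WEIGHTS = ((1.0, 1.0), (1.0, -1.0))


def frozen_base(x):
    """Map a raw input to features using the fixed base weights."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in BASE_WEIGHTS]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(one_hot, probs):
    """Per-sample loss: -sum(t * log(p)) over the classes."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs) if t > 0)


class DenseSoftmax:
    """Trainable head: the analogue of Dense(num_classes, activation='softmax')."""

    def __init__(self, n_in, n_out, seed=0):
        rng = random.Random(seed)
        self.w = [[rng.uniform(-0.1, 0.1) for _ in range(n_out)] for _ in range(n_in)]
        self.b = [0.0] * n_out

    def num_params(self):
        return len(self.w) * len(self.b) + len(self.b)  # weights + biases

    def forward(self, feats):
        logits = [sum(f * self.w[i][j] for i, f in enumerate(feats)) + self.b[j]
                  for j in range(len(self.b))]
        return softmax(logits)

    def sgd_step(self, feats, one_hot, lr):
        """One plain SGD update on a single example; returns the loss before the update."""
        probs = self.forward(feats)
        # For softmax + cross-entropy, d(loss)/d(logits) simplifies to probs - one_hot.
        grad = [p - t for p, t in zip(probs, one_hot)]
        for i, f in enumerate(feats):
            for j, g in enumerate(grad):
                self.w[i][j] -= lr * g * f
        for j, g in enumerate(grad):
            self.b[j] -= lr * g
        return categorical_crossentropy(one_hot, probs)


def train(head, data, epochs, lr):
    """Update only the head; return the mean loss of each epoch."""
    losses = []
    for _ in range(epochs):
        total = 0.0
        for x, y in data:
            total += head.sgd_step(frozen_base(x), y, lr)
        losses.append(total / len(data))
    return losses


# Toy usage with made-up data: two classes, one-hot labels.
data = [([1.0, 0.0], [1.0, 0.0]), ([0.9, 0.1], [1.0, 0.0]),
        ([0.0, 1.0], [0.0, 1.0]), ([0.1, 0.9], [0.0, 1.0])]
head = DenseSoftmax(n_in=2, n_out=2)
losses = train(head, data, epochs=20, lr=0.5)
print(f"epoch 1 loss {losses[0]:.3f} -> epoch 20 loss {losses[-1]:.3f}")
```

With `pooling='avg'`, `ResNet50(include_top=False)` outputs a 2048-dimensional vector per image, so the script's `Dense(num_classes)` head has 2048 × 2 weights plus 2 biases, 4098 trainable parameters in total; `DenseSoftmax(2048, 2).num_params()` reproduces that count. Pairing softmax with categorical cross-entropy is what makes the head's gradient with respect to the logits reduce to `probs - one_hot`.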

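The data side can be sketched the same way. The block below is a simplified, standard-library re-implementation (not the Keras source) of three things the script relies on: what `resnet50.preprocess_input` does to a pixel (convert RGB to BGR and subtract the ImageNet channel means, without scaling), how `flow_from_directory` derives `class_indices` from subdirectory names sorted alphanumerically, and how many batches it takes to cover a folder. The helper names and the `urban`/`rural` folders are hypothetical.

```python
import math
import os
import tempfile

# ImageNet per-channel means in BGR order, as used by the "caffe"-style
# preprocessing that keras.applications.resnet50.preprocess_input applies.
IMAGENET_MEAN_BGR = (103.939, 116.779, 123.68)


def caffe_preprocess_pixel(rgb):
    """Simplified sketch for one pixel: RGB -> BGR, then subtract the channel means (no scaling)."""
    r, g, b = rgb
    return tuple(v - m for v, m in zip((b, g, r), IMAGENET_MEAN_BGR))


def infer_class_indices(directory):
    """Sketch of flow_from_directory's default labelling: one class per subdirectory, sorted."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}


def one_hot(index, num_classes):
    """The label vector class_mode='categorical' yields for one image."""
    return [1.0 if i == index else 0.0 for i in range(num_classes)]


def steps_to_cover(num_images, batch_size):
    """Smallest number of batches that visits every image at least once."""
    return math.ceil(num_images / batch_size)


# Usage with a throwaway directory and hypothetical class folders.
with tempfile.TemporaryDirectory() as root:
    for name in ("urban", "rural"):
        os.makedirs(os.path.join(root, name))
    indices = infer_class_indices(root)

print(indices)                                   # {'rural': 0, 'urban': 1}
print(one_hot(indices["urban"], len(indices)))   # [0.0, 1.0]
print(caffe_preprocess_pixel((255, 0, 0)))
```

Because the `preprocessing_function` line is commented out in the script, the network currently receives raw 0-255 RGB pixels rather than the zero-centred BGR values the ImageNet weights expect. As for step counts, `steps_per_epoch=4` with `batch_size=22` draws 4 × 22 = 88 training images per epoch, and `validation_steps=6` with the validation generator's default `batch_size` of 32 draws 6 batches of up to 32 images; whether either amounts to exactly one pass depends on how many images the folders contain.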
Result of running the script above:

Environment configuration files:

https://pan.baidu.com/s/1fBzSbJekdorXo7ZRGSZzig (extraction code: tzvm)

References:

https://www.kaggle.com/dansbecker/exercise-using-transfer-learning/notebook

https://www.kaggle.com/dansbecker/transfer-learning/notebook

https://keras-cn.readthedocs.io/en/latest/preprocessing/image/#imagedatagenerator

https://keras-cn.readthedocs.io/en/latest/models/sequential/#fit_generator

https://keras.io/applications/#resnet50


Reposted from blog.csdn.net/wxf2012301351/article/details/80238249