Transfer learning & fine-tuning

1. Introduction

Transfer learning involves taking features learned on one problem and using them on a new similar problem. For example, features from a model that has learned to recognize raccoons might help bootstrap a model designed to recognize civet cats.

Transfer learning is typically used for tasks where your dataset has too little data to train a full-sized model from scratch.

In the context of deep learning, the most common manifestation of transfer learning is the following workflow:

  1. Take layers from a previously trained model.
  2. Freeze them, so as to avoid destroying any of the information they contain during future training rounds.
  3. Add some new, trainable layers on top of the frozen layers. They will learn to turn the old features into predictions on the new dataset.
  4. Train the new layers on your dataset.

A last, optional step is fine-tuning, which consists of unfreezing the entire model you obtained above (or part of it) and re-training it on the new data with a very low learning rate. This can potentially achieve meaningful improvements by incrementally adapting the pretrained features to the new data.

First, we'll detail the Keras trainable API, which is the basis for most transfer learning and fine-tuning workflows.

We will then demonstrate a typical workflow by taking a model pre-trained on the ImageNet dataset and retraining it on the Kaggle "Cats vs. Dogs" classification dataset.

2. Freezing layers: understanding the trainable attribute

Layers & models have three weight attributes:

  • weights is the list of all weight variables of the layer.
  • trainable_weights is the list of those that are meant to be updated (via gradient descent) to minimize the loss during training.
  • non_trainable_weights is the list of those that aren't meant to be trained. Typically, they are updated by the model during the forward pass.

2.1 Setup

import numpy as np
import tensorflow as tf
from tensorflow import keras

2.2 Example: This Dense layer has 2 trainable weights (kernel and bias)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

In general, all weights are trainable weights. The only built-in layer with non-trainable weights is the BatchNormalization layer. It uses non-trainable weights to track the mean and variance of the inputs during training. To learn how to use non-trainable weights in your own custom layers, see the guide to writing new layers from scratch.
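To illustrate, here is a minimal sketch (not from the guide) of a custom layer holding a non-trainable weight, in this case a running count of how many samples it has seen; the layer name and the counting logic are assumptions made for this example.

class SampleCounter(keras.layers.Layer):
    def build(self, input_shape):
        # `trainable=False` registers the variable under `non_trainable_weights`,
        # so gradient descent never touches it.
        self.seen = self.add_weight(
            name="seen", shape=(), dtype="int64", initializer="zeros", trainable=False
        )

    def call(self, inputs):
        # Non-trainable weights are updated manually during the forward pass.
        self.seen.assign_add(tf.cast(tf.shape(inputs)[0], "int64"))
        return inputs

layer = SampleCounter()
_ = layer(tf.zeros((8, 4)))
print("trainable_weights:", len(layer.trainable_weights))          # 0
print("non_trainable_weights:", len(layer.non_trainable_weights))  # 1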

2.3 Example: The BatchNormalization layer has 2 trainable weights and 2 non-trainable weights

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Layers and models also feature a boolean attribute trainable, whose value can be changed. Setting layer.trainable to False moves all of the layer's weights from trainable to non-trainable. This is called "freezing" the layer: the state of a frozen layer won't be updated during training (either when training with fit() or when training with any custom loop that relies on trainable_weights to apply gradient updates).

2.4 Example: Set trainable to False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

When a trainable weight becomes non-trainable, its value is no longer updated during training.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 0s 333ms/step - loss: 0.1007

Do not confuse the layer.trainable attribute with the training argument in layer.__call__(), which controls whether the layer should run its forward pass in inference mode or in training mode. See the Keras FAQ for details.
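As a quick illustration (not part of the original guide), a Dropout layer makes the distinction obvious: it has no weights, so trainable is irrelevant to it, while the training argument decides whether it actually drops units.

dropout = keras.layers.Dropout(0.5)
x = tf.ones((1, 10))

print(dropout(x, training=False))  # Inference mode: inputs pass through unchanged
print(dropout(x, training=True))   # Training mode: roughly half the units are zeroed out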

2.5 Recursive setting of trainable attribute

If you set trainable = False on a model or on any layer that has sublayers, all children layers become non-trainable as well.

Example:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

3. The typical transfer-learning workflow

This leads us to how to implement a typical transfer learning workflow in Keras:

  1. Instantiate a base model and load pretrained weights into it.
  2. Freeze all layers in the base model by setting trainable=False.
  3. Create a new model on top of the output of one (or more) layers of the base model.
  4. Train your new model on your new dataset.

Note that an alternative, more lightweight workflow could also be:

  1. Instantiate a base model and load pretrained weights into it.
  2. Run your new dataset through it and record the output of one (or more) layers from the base model. This is called feature extraction.
  3. Use this output as input data for a new, smaller model.

A key advantage of the second workflow is that you only need to run the base model on the data once, rather than once per training epoch. So it's faster and cheaper.
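For reference, here is a minimal sketch of this feature-extraction workflow, assuming a batched tf.data dataset called new_dataset that yields (images, labels) pairs of 150x150 RGB images; the pooling="avg" argument and the size of the classifier are assumptions made for this sketch, not part of the guide.

base_model = keras.applications.Xception(
    weights="imagenet",
    input_shape=(150, 150, 3),
    include_top=False,
    pooling="avg",  # Global average pooling: output is a (batch, 2048) feature vector
)
base_model.trainable = False

# Run the new dataset through the frozen base model once and record the features.
features = base_model.predict(new_dataset.map(lambda x, y: x))
labels = np.concatenate([y.numpy() for x, y in new_dataset], axis=0)

# Use these features as input data for a new, smaller model.
classifier = keras.Sequential([keras.Input(shape=(2048,)), keras.layers.Dense(1)])
classifier.compile(optimizer="adam",
                   loss=keras.losses.BinaryCrossentropy(from_logits=True))
classifier.fit(features, labels, epochs=20)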

However, one problem with the second workflow is that it doesn't allow you to dynamically modify the input data of your new model during training, which is required, for example, when doing data augmentation. Transfer learning is typically used when the new dataset has too little data to train a full-scale model from scratch, and in such scenarios data augmentation is very important. So in what follows, we will focus on the first workflow.

Here's the first workflow in Keras:

First, instantiate a base model with pretrained weights.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Then, freeze the base model.

base_model.trainable = False

Then, create a new model on top.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Train the model on new data.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

4. Fine-tuning

Once your model has converged on the new data, you can try to unfreeze all or part of the base model and retrain the whole model end-to-end with a very low learning rate.

This is an optional last step that may give you incremental improvements. It can also lead to rapid overfitting - keep that in mind.

Crucially, this step is performed only after the model with frozen layers has been trained to convergence. If you mix randomly initialized trainable layers with trainable layers containing pretrained features, the randomly initialized layers will cause very large gradient updates during training, which will corrupt your pretrained features.

It's also critical to use a very low learning rate at this stage, because you are training a much larger model than in the first round of training, on a dataset that is typically very small. As a result, you are at risk of overfitting very quickly if you apply large weight updates. Here, you only want to readapt the pretrained weights in an incremental way.

Here's how to implement fine-tuning of the whole base model:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are taken into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Important note about compile() and trainable

Calling compile() on a model is meant to "freeze" the behavior of that model. This implies that the trainable attribute values at the time the model is compiled should be preserved throughout the lifetime of that model, until compile() is called again. Hence, if you change any trainable value, make sure to call compile() again on your model for your changes to be taken into account.

Important note about BatchNormalization layers

Many image models contain BatchNormalization layers. That layer is a special case on just about every imaginable count. Here are a few things to keep in mind.

  • BatchNormalization contains 2 non-trainable weights that get updated during training. These are the variables tracking the mean and variance of the inputs.
  • When you set bn_layer.trainable = False, the BatchNormalization layer will run in inference mode and will not update its mean and variance statistics. This is not the case for other layers in general, since weight trainability and inference/training modes are two orthogonal concepts. But the two are tied in the case of the BatchNormalization layer.
  • When you unfreeze a model that contains BatchNormalization layers in order to do fine-tuning, you should keep the BatchNormalization layers in inference mode by passing training=False when calling the base model. Otherwise, the updates applied to the non-trainable weights will suddenly destroy what the model has learned.

You'll see this pattern in action in the end-to-end example at the end of this guide.

5. Transfer learning & fine-tuning with a custom training loop

The workflow stays essentially the same if you use your own low-level training loop instead of fit(). Just be careful to only take the list model.trainable_weights into account when applying gradient updates:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

The same applies to fine-tuning.
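For example, a fine-tuning pass with the same loop could look like this minimal sketch; the learning rate choice and the loop body are illustrative, not from the guide.

base_model.trainable = True  # Unfreeze the base model
optimizer = keras.optimizers.Adam(1e-5)  # Very low learning rate for fine-tuning

for inputs, targets in new_dataset:
    with tf.GradientTape() as tape:
        # The base model still runs in inference mode, since it was called
        # with `training=False` when the model was built.
        predictions = model(inputs)
        loss_value = loss_fn(targets, predictions)
    # `model.trainable_weights` now also includes the base model's weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))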

6. End-to-End Example: Fine-tuning an Image Classification Model on the Cats vs. Dogs Dataset

To solidify these concepts, let us walk you through a concrete example of end-to-end transfer learning and fine-tuning. We will load the Xception model pre-trained on ImageNet and use it on the Kaggle "Cats vs. Dogs" classification dataset.

6.1 Get data

First, let's fetch the Cats vs. Dogs dataset using TFDS. If you have your own dataset, you may want to use the utility tf.keras.utils.image_dataset_from_directory to generate a similar labeled dataset object from a set of images on disk, filed into class-specific folders.
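For example, assuming a hypothetical directory layout such as path/to/pets/cat/ and path/to/pets/dog/, the call might look like this; note that, unlike the TFDS datasets used below, the resulting datasets come back already resized and batched.

train_ds = tf.keras.utils.image_dataset_from_directory(
    "path/to/pets",  # Hypothetical directory with one subfolder per class
    validation_split=0.2,
    subset="training",
    seed=1337,
    image_size=(150, 150),
    batch_size=32,
)
validation_ds = tf.keras.utils.image_dataset_from_directory(
    "path/to/pets",
    validation_split=0.2,
    subset="validation",
    seed=1337,
    image_size=(150, 150),
    batch_size=32,
)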

Transfer learning is most useful when working with very small datasets. To keep our dataset small, we will use 40% of the original training data (25,000 images) for training, 10% for validation, and 10% for testing.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

These are the first 9 images in the training dataset - as you can see, they vary in size.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

We can also see that label 1 is "dog" and label 0 is "cat".

6.2 Standardizing the data

Our raw images come in a variety of sizes. In addition, each pixel consists of 3 integer values between 0 and 255 (RGB level values). This isn't a great fit for feeding a neural network. We need to do two things:

  • Standardize to a fixed image size. We pick 150x150.
  • Normalize pixel values between -1 and 1. We'll do this using a Rescaling layer as part of the model itself.

In general, it's a good practice to develop models that take raw data as input, as opposed to models that take already-preprocessed data. The reason is that, if your model expects preprocessed data, any time you export your model to use it elsewhere (in a web browser, in a mobile app), you'll need to reimplement the exact same preprocessing pipeline. This gets very tricky very quickly. So the model should expect inputs that are as close as possible to raw data.

Here, we will resize the images in the data pipeline (because a deep neural network can only process contiguous batches of data), and we will do the input value scaling as part of the model, when we create it.

Let's resize the images to 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Also, let's batch the data and use caching and prefetching to optimize loading speed.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

6.3 Using random data augmentation

When you don't have a large dataset of images, it's good practice to artificially introduce sample diversity by applying random but realistic transformations to the training images, such as random horizontal flips or small random rotations. This helps expose the model to different aspects of the training data while slowing down overfitting.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Let's visualize what the first image of the first batch looks like after various random transformations:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")


7. Build a model

Now let's build a model following the blueprint we explained earlier.

Note that:

  • We add a Rescaling layer to scale input values (initially in the [0, 255] range) to the [-1, 1] range.
  • We add a Dropout layer before the classification layer, for regularization.
  • We make sure to pass training=False when calling the base model, so that it runs in inference mode and the batchnorm statistics don't get updated even after we unfreeze the base model for fine-tuning.

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

8. Training the top layer

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)

Epoch 1/20
291/291 [==============================] - 133s 451ms/step - loss: 0.1670 - binary_accuracy: 0.9267 - val_loss: 0.0830 - val_binary_accuracy: 0.9716
Epoch 2/20
291/291 [==============================] - 135s 465ms/step - loss: 0.1208 - binary_accuracy: 0.9502 - val_loss: 0.0768 - val_binary_accuracy: 0.9716
Epoch 3/20
291/291 [==============================] - 135s 463ms/step - loss: 0.1062 - binary_accuracy: 0.9572 - val_loss: 0.0757 - val_binary_accuracy: 0.9716
Epoch 4/20
291/291 [==============================] - 137s 469ms/step - loss: 0.1024 - binary_accuracy: 0.9554 - val_loss: 0.0733 - val_binary_accuracy: 0.9725
Epoch 5/20
291/291 [==============================] - 137s 470ms/step - loss: 0.1004 - binary_accuracy: 0.9587 - val_loss: 0.0735 - val_binary_accuracy: 0.9729
Epoch 6/20
291/291 [==============================] - 136s 467ms/step - loss: 0.0979 - binary_accuracy: 0.9577 - val_loss: 0.0747 - val_binary_accuracy: 0.9708
Epoch 7/20
291/291 [==============================] - 134s 462ms/step - loss: 0.0998 - binary_accuracy: 0.9596 - val_loss: 0.0706 - val_binary_accuracy: 0.9725
Epoch 8/20
291/291 [==============================] - 133s 457ms/step - loss: 0.1029 - binary_accuracy: 0.9592 - val_loss: 0.0720 - val_binary_accuracy: 0.9733
Epoch 9/20
291/291 [==============================] - 135s 466ms/step - loss: 0.0937 - binary_accuracy: 0.9625 - val_loss: 0.0707 - val_binary_accuracy: 0.9721
Epoch 10/20
291/291 [==============================] - 137s 472ms/step - loss: 0.0967 - binary_accuracy: 0.9580 - val_loss: 0.0720 - val_binary_accuracy: 0.9712
Epoch 11/20
291/291 [==============================] - 135s 463ms/step - loss: 0.0961 - binary_accuracy: 0.9612 - val_loss: 0.0802 - val_binary_accuracy: 0.9699
Epoch 12/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0963 - binary_accuracy: 0.9638 - val_loss: 0.0721 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 136s 468ms/step - loss: 0.0925 - binary_accuracy: 0.9635 - val_loss: 0.0736 - val_binary_accuracy: 0.9686
Epoch 14/20
291/291 [==============================] - 138s 476ms/step - loss: 0.0909 - binary_accuracy: 0.9624 - val_loss: 0.0766 - val_binary_accuracy: 0.9703
Epoch 15/20
291/291 [==============================] - 136s 467ms/step - loss: 0.0949 - binary_accuracy: 0.9598 - val_loss: 0.0704 - val_binary_accuracy: 0.9725
Epoch 16/20
291/291 [==============================] - 133s 456ms/step - loss: 0.0969 - binary_accuracy: 0.9586 - val_loss: 0.0722 - val_binary_accuracy: 0.9708
Epoch 17/20
291/291 [==============================] - 135s 464ms/step - loss: 0.0913 - binary_accuracy: 0.9635 - val_loss: 0.0718 - val_binary_accuracy: 0.9716
Epoch 18/20
291/291 [==============================] - 137s 472ms/step - loss: 0.0915 - binary_accuracy: 0.9639 - val_loss: 0.0727 - val_binary_accuracy: 0.9725
Epoch 19/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0938 - binary_accuracy: 0.9631 - val_loss: 0.0707 - val_binary_accuracy: 0.9733
Epoch 20/20
291/291 [==============================] - 134s 460ms/step - loss: 0.0971 - binary_accuracy: 0.9609 - val_loss: 0.0714 - val_binary_accuracy: 0.9716

<keras.callbacks.History at 0x7f4494e38f70>

9. Do a round of fine-tuning on the entire model

Finally, let's unfreeze the base model and train the whole thing end-to-end with a low learning rate.

Importantly, although the base model becomes trainable, it is still running in inference mode since we passed training=False when we called it while building the model. This means that the batch normalization layers inside won't update their batch statistics. If they did, they would destroy the representations the model has learned so far.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 567s 2s/step - loss: 0.0749 - binary_accuracy: 0.9689 - val_loss: 0.0605 - val_binary_accuracy: 0.9776
Epoch 2/10
291/291 [==============================] - 551s 2s/step - loss: 0.0559 - binary_accuracy: 0.9770 - val_loss: 0.0507 - val_binary_accuracy: 0.9798
Epoch 3/10
291/291 [==============================] - 545s 2s/step - loss: 0.0444 - binary_accuracy: 0.9832 - val_loss: 0.0502 - val_binary_accuracy: 0.9807
Epoch 4/10
291/291 [==============================] - 558s 2s/step - loss: 0.0365 - binary_accuracy: 0.9874 - val_loss: 0.0506 - val_binary_accuracy: 0.9807
Epoch 5/10
291/291 [==============================] - 550s 2s/step - loss: 0.0276 - binary_accuracy: 0.9890 - val_loss: 0.0477 - val_binary_accuracy: 0.9802
Epoch 6/10
291/291 [==============================] - 588s 2s/step - loss: 0.0206 - binary_accuracy: 0.9916 - val_loss: 0.0444 - val_binary_accuracy: 0.9832
Epoch 7/10
291/291 [==============================] - 542s 2s/step - loss: 0.0206 - binary_accuracy: 0.9923 - val_loss: 0.0502 - val_binary_accuracy: 0.9828
Epoch 8/10
291/291 [==============================] - 544s 2s/step - loss: 0.0153 - binary_accuracy: 0.9939 - val_loss: 0.0509 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 548s 2s/step - loss: 0.0156 - binary_accuracy: 0.9934 - val_loss: 0.0610 - val_binary_accuracy: 0.9807
Epoch 10/10
291/291 [==============================] - 546s 2s/step - loss: 0.0176 - binary_accuracy: 0.9936 - val_loss: 0.0561 - val_binary_accuracy: 0.9789

<keras.callbacks.History at 0x7f4495056040>

Fine-tuning gives us a nice improvement here after 10 epochs.
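As an optional final check (not part of the original guide's output), you can evaluate the fine-tuned model on the held-out test split prepared in section 6.2:

loss, acc = model.evaluate(test_ds)
print("Test accuracy: %.4f" % acc)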

Reference

https://keras.io/guides/transfer_learning/
