O TensorFlow define sua própria função de callback EarlyStop monitorando o indicador de perda

O modelo de treinamento do TensorFlow precisa passar por várias épocas, mas não é que quanto mais épocas, melhor. É muito provável que quando metade das épocas for treinada, o efeito do modelo comece a diminuir. É por isso que precisamos interromper o treinamento e salvar o modelo a tempo. Para atender a esse requisito, podemos personalizar a função de retorno de chamada para detectar automaticamente a perda do modelo. Desde que um determinado limite seja atingido, interrompemos manualmente o treinamento do modelo.

código completo

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/4
 * 时间: 10:32
 * 描述:
"""
import numpy as np
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras.layers import *


def get_model():
    inputs = Input(shape=(784,))
    outputs = Dense(1)(inputs)
    model = Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss='mean_squared_error',
        metrics=['mean_absolute_error']
    )
    return model


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]


class CustomEarlyStoppingAtMinLoss(keras.callbacks.Callback):
    def __init__(self, patience=0):
        super(CustomEarlyStoppingAtMinLoss, self).__init__()
        self.patience = patience
        self.best_weights = None
        self.wait = 0
        self.stopped_epoch = 0
        self.best = np.Inf

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        current = logs.get("loss")
        if np.less(current, self.best):
            self.best = current
            self.wait = 0
            self.best_weights = self.model.get_weights()
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print("Restoring model weights from the end of the best epoch.")
                self.model.set_weights(self.best_weights)

    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=10,
    verbose=1,
    validation_split=0.5,
    callbacks=[CustomEarlyStoppingAtMinLoss()],
)

Acho que você gosta

Origin blog.csdn.net/m0_47256162/article/details/122298007
Recomendado
Clasificación