Función de devolución de llamada personalizada de TensorFlow [devolución de llamada global, lote, época]

Dado que TensorFlow ha encapsulado la fase de entrenamiento de todo el modelo, no podemos definir nuestro propio comportamiento durante el entrenamiento o la evaluación de la predicción, como imprimir el progreso del entrenamiento, guardar la precisión de la pérdida, etc. Aquí es donde podemos usar las funciones de devolución de llamada.

Todas las funciones de devolución de llamada subclasifican la clase keras.callbacks.Callback y anulan un conjunto de métodos que se llaman en varias etapas de entrenamiento, prueba y predicción. Las funciones de devolución de llamada son útiles para comprender el estado interno y las estadísticas del modelo durante el entrenamiento.

Puede pasar una lista de funciones de devolución de llamada (como devoluciones de llamada de argumento de palabra clave) a los siguientes métodos de modelo:

  • keras.Modelo.fit()
  • keras.Modelo.evaluar()
  • duro.Modelo.predecir ()

Descripción general del método de función de devolución de llamada
全局方法
on_(train|test|predict) begin(self, logs=None)
se llama al comienzo de fit/evaluate/predict.
on
(train|test|predict) end(self, logs=None)
se llama al final de fit/evaluate/predict.
批次级方法(仅训练)
on
(train|test|predict) se llama a batch_begin(self, batch, logs=None)
justo antes de que se procese el lote durante train/test/predict.
on
(train|test|predict)_batch_end(self, batch, logs=None)
se llama al final de un lote de entrenamiento/prueba/predicción. En este método, los registros son un diccionario que contiene los resultados de las métricas.
周期级方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
se llama al comienzo de una época durante el entrenamiento.
on_epoch_end(self, epoch, logs=None)
se llama al comienzo de la época durante el entrenamiento.

código completo

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/4
 * 时间: 10:02
 * 描述:
"""
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras.layers import *


def get_model():
    inputs = Input(shape=(784,))
    outputs = Dense(1)(inputs)
    model = Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss='mean_squared_error',
        metrics=['mean_absolute_error']
    )
    return model


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]


class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[CustomCallback()]
)

res = model.predict(x_test,
                    batch_size=128,
                    callbacks=[CustomCallback()])

Supongo que te gusta

Origin blog.csdn.net/m0_47256162/article/details/122297340
Recomendado
Clasificación