TensorFlow自定义回调函数【全局回调、批次、epoch】

由于TensorFlow已经将整个模型的训练阶段进行了封装,所以我们无法在训练期间或者预测评估期间定义自己的行为,例如打印训练进度、保存损失精度等,这是我们就可以利用回调函数

所有回调函数都将 keras.callbacks.Callback 类作为子类,并重写在训练、测试和预测的各个阶段调用的一组方法。回调函数对于在训练期间了解模型的内部状态和统计信息十分有用。

您可以将回调函数的列表(作为关键字参数 callbacks)传递给以下模型方法:

  • keras.Model.fit()
  • keras.Model.evaluate()
  • keras.Model.predict()

回调函数方法概述
全局方法
on_(train|test|predict)begin(self, logs=None)
在 fit/evaluate/predict 开始时调用。
on
(train|test|predict)end(self, logs=None)
在 fit/evaluate/predict 结束时调用。
批次级方法(仅训练)
on
(train|test|predict)batch_begin(self, batch, logs=None)
正好在训练/测试/预测期间处理批次之前调用。
on
(train|test|predict)_batch_end(self, batch, logs=None)
在训练/测试/预测批次结束时调用。在此方法中,logs 是包含指标结果的字典。
周期级方法(仅训练)
on_epoch_begin(self, epoch, logs=None)
在训练期间周期开始时调用。
on_epoch_end(self, epoch, logs=None)
在训练期间周期开始时调用。

完整代码

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/4
 * 时间: 10:02
 * 描述:
"""
import tensorflow as tf
from keras import Model
from tensorflow import keras
from tensorflow.keras.layers import *


def get_model():
    inputs = Input(shape=(784,))
    outputs = Dense(1)(inputs)
    model = Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
        loss='mean_squared_error',
        metrics=['mean_absolute_error']
    )
    return model


(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype('float32') / 255.0
x_test = x_test.reshape(-1, 784).astype('float32') / 255.0

x_train = x_train[:1000]
y_train = y_train[:1000]
x_test = x_test[:1000]
y_test = y_test[:1000]


class CustomCallback(keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        keys = list(logs.keys())
        print("Starting training; got log keys: {}".format(keys))

    def on_train_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop training; got log keys: {}".format(keys))

    def on_epoch_begin(self, epoch, logs=None):
        keys = list(logs.keys())
        print("Start epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_epoch_end(self, epoch, logs=None):
        keys = list(logs.keys())
        print("End epoch {} of training; got log keys: {}".format(epoch, keys))

    def on_test_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start testing; got log keys: {}".format(keys))

    def on_test_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop testing; got log keys: {}".format(keys))

    def on_predict_begin(self, logs=None):
        keys = list(logs.keys())
        print("Start predicting; got log keys: {}".format(keys))

    def on_predict_end(self, logs=None):
        keys = list(logs.keys())
        print("Stop predicting; got log keys: {}".format(keys))

    def on_train_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: start of batch {}; got log keys: {}".format(batch, keys))

    def on_train_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Training: end of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: start of batch {}; got log keys: {}".format(batch, keys))

    def on_test_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Evaluating: end of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_begin(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: start of batch {}; got log keys: {}".format(batch, keys))

    def on_predict_batch_end(self, batch, logs=None):
        keys = list(logs.keys())
        print("...Predicting: end of batch {}; got log keys: {}".format(batch, keys))


model = get_model()
model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=1,
    verbose=0,
    validation_split=0.5,
    callbacks=[CustomCallback()],
)

res = model.evaluate(
    x_test,
    y_test,
    batch_size=128,
    verbose=0,
    callbacks=[CustomCallback()]
)

res = model.predict(x_test,
                    batch_size=128,
                    callbacks=[CustomCallback()])

猜你喜欢

转载自blog.csdn.net/m0_47256162/article/details/122297340