tensorflow2.0常用回调函数小结

经查看官方文档将常用回调函数做以下小结,目的是了解每个回调函数的作用与参数用法。

上图是tf2.0的全部回调函数,在这里介绍常用的4个回调函数:EarlyStopping,tensorboard,ModelCheckpoint,history。

1、tf.keras.callbacks.EarlyStopping

目的/作用:当监控的值停止变化时,提前结束训练。

定义:

tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto',
    baseline=None, restore_best_weights=False
)

由上面的代码段可以得知,当未自己手动设置monitor时,默认监控的是验证集的loss(val_loss)。

常用参数介绍:

monitor:监控的值。
min_delta:监视值的最小变化,即,绝对变化小于min_delta的情况,将视为没有变化
patience:在多少个epoch,监控的值没有变化后,将停止训练。(也就是连续多少个epoch,监控值的绝对变化小于min_delta)

示例:

callback = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
# This callback will stop the training when there is no improvement in
# the validation loss for three consecutive epochs.
model.fit(data, labels, epochs=100, callbacks=[callback],
    validation_data=(val_data, val_labels))

2、tf.keras.callbacks.TensorBoard

作用:tensorflow的可视化工具

定义:

tf.keras.callbacks.TensorBoard(
    log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
    update_freq='epoch', profile_batch=2, embeddings_freq=0,
    embeddings_metadata=None, **kwargs
)

常用参数:

log_dir:将TensorBoard解析的日志文件保存到的目录路径。

其余用到再补充

示例:

logdir = os.path.join("callbacks")
if not os.path.exists(logdir):
    os.mkdir(logdir)

callbacks = [
    keras.callbacks.TensorBoard(logdir),]

history = model.fit(x_train_scaled, y_train, epochs=100,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

tensorboard显示:

3、tf.keras.callbacks.ModelCheckpoint

作用:在每一次epoch后保存模型

定义:

tf.keras.callbacks.ModelCheckpoint(
    filepath, monitor='val_loss', verbose=0, save_best_only=False,
    save_weights_only=False, mode='auto', save_freq='epoch', **kwargs
)

常用参数:

filepath:字符串,保存模型文件的路径。

示例:

logdir = os.path.join("callbacks")
output_model_file = os.path.join(logdir,
                                 "fashion_mnist_model.h5")

callbacks = [
    keras.callbacks.ModelCheckpoint(output_model_file,
                                    save_best_only = True),#保存最好的模型,默认保存最近的
]

history = model.fit(x_train_scaled, y_train, epochs=100,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)

4、tf.keras.callbacks.History

这个回调函数会自动应用到每一个keras模型,History对象通过模型的fit方法得到返回。

history = model.fit(x_train_scaled, y_train, epochs=100,
                    validation_data=(x_valid_scaled, y_valid),
                    callbacks = callbacks)
原创文章 46 获赞 49 访问量 2184

猜你喜欢

转载自blog.csdn.net/qq_41660119/article/details/105832293
今日推荐