概述
在训练时候,我们希望在训练中有所输出和判断,而不是一直到训练结束才能实现交互,那么回调函数就是你最好的选择。本篇博文针对回调函数的功能,种类以及代码使用进行讲解
功能与种类
回调函数的用法例如下所示
- 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前权重。
- 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中得到的最佳模型)。
- 在训练过程中动态调节某些参数值:比如优化器的学习率。
- 在训练过程中记录训练指标和验证指标,或将模型学到的表示可视化(这些表示也在不断更新):如Keras 进度条就是一个回调函数
keras.callbacks 模块包含许多内置的回调函数,下面列出了其中一些
keras.callbacks.ModelCheckpoint
keras.callbacks.EarlyStopping
keras.callbacks.LearningRateScheduler
keras.callbacks.ReduceLROnPlateau
keras.callbacks.CSVLogger
ModelCheckpoint 与EarlyStopping 回调函数
import keras
callbacks_list = [
keras.callbacks.EarlyStopping(#如果不再改善,就中断训练
monitor='acc',
patience=1,#如果精度在多于一轮的时间(即两轮)内不再改善,中断训练
),
keras.callbacks.ModelCheckpoint(
filepath='my_model.h5',
monitor='val_loss',#这两个参数的含义是,如果val_loss 没有改善,那么不需要覆盖模型文件。这就可以始终保存在训练过程中见到的最佳模型
save_best_only=True,
)
]
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['acc'])#你监控精度,所以它应该是模型指标的一部分
model.fit(x, y,
epochs=10,
batch_size=32,
callbacks=callbacks_list,#前面的设置在这里使用
validation_data=(x_val, y_val))
ReduceLROnPlateau 回调函数
如果损失不再改善,就降低学习率
callbacks_list = [
keras.callbacks.ReduceLROnPlateau(
monitor='val_loss'
factor=0.1,#触发时将学习率除以10
patience=10,#如果验证损失在10 轮内都没有改善,那么就触发这个回调函数
)
]
model.fit(x, y,
epochs=10,
batch_size=32,
callbacks=callbacks_list,
validation_data=(x_val, y_val))
编写自己的回调函数
下面是一个自定义回调函数的简单示例,它可以在每轮结束后将模型每层的激活保存到硬盘(格式为Numpy 数组),这个激活是对验证集的第一个样本计算得到的。
import keras
import numpy as np
class ActivationLogger(keras.callbacks.Callback):
def set_model(self, model):
self.model = model#在训练之前由父模型调用,告诉回调函数是哪个模型在调用它
layer_outputs = [layer.output for layer in model.layers]
#模型实例,返回每层的激活
self.activations_model = keras.models.Model(model.input,
layer_outputs)
def on_epoch_end(self, epoch, logs=None):
if self.validation_data is None:
raise RuntimeError('Requires validation_data.')
#获取验证数据的第一个输入样本
validation_sample = self.validation_data[0][0:1]
activations = self.activations_model.predict(validation_sample)
f = open('activations_at_epoch_' + str(epoch) + '.npz', 'w')
np.savez(f, activations)
f.close()