keras回调函数的使用

概述

在训练时候,我们希望在训练中有所输出和判断,而不是一直到训练结束才能实现交互,那么回调函数就是你最好的选择。本篇博文针对回调函数的功能,种类以及代码使用进行讲解

功能与种类

回调函数的用法例如下所示

  • 模型检查点(model checkpointing):在训练过程中的不同时间点保存模型的当前权重。
  • 提前终止(early stopping):如果验证损失不再改善,则中断训练(当然,同时保存在训练过程中得到的最佳模型)。
  • 在训练过程中动态调节某些参数值:比如优化器的学习率。
  • 在训练过程中记录训练指标和验证指标,或将模型学到的表示可视化(这些表示也在不断更新):如Keras 进度条就是一个回调函数

keras.callbacks 模块包含许多内置的回调函数,下面列出了其中一些

	keras.callbacks.ModelCheckpoint
	keras.callbacks.EarlyStopping
	keras.callbacks.LearningRateScheduler
	keras.callbacks.ReduceLROnPlateau
	keras.callbacks.CSVLogger

ModelCheckpoint 与EarlyStopping 回调函数

import keras
callbacks_list = [
	keras.callbacks.EarlyStopping(#如果不再改善,就中断训练
	monitor='acc',
	patience=1,#如果精度在多于一轮的时间(即两轮)内不再改善,中断训练
),
keras.callbacks.ModelCheckpoint(
	filepath='my_model.h5',
	monitor='val_loss',#这两个参数的含义是,如果val_loss 没有改善,那么不需要覆盖模型文件。这就可以始终保存在训练过程中见到的最佳模型
	save_best_only=True,
)
]
model.compile(optimizer='rmsprop',
			  loss='binary_crossentropy',
			  metrics=['acc'])#你监控精度,所以它应该是模型指标的一部分
model.fit(x, y,
		  epochs=10,
		  batch_size=32,
		  callbacks=callbacks_list,#前面的设置在这里使用
		  validation_data=(x_val, y_val))

ReduceLROnPlateau 回调函数

如果损失不再改善,就降低学习率

callbacks_list = [
	keras.callbacks.ReduceLROnPlateau(
		monitor='val_loss'
		factor=0.1,#触发时将学习率除以10
		patience=10,#如果验证损失在10 轮内都没有改善,那么就触发这个回调函数
	)
]
model.fit(x, y,
		 epochs=10,
		 batch_size=32,
		 callbacks=callbacks_list,
		 validation_data=(x_val, y_val))

编写自己的回调函数

下面是一个自定义回调函数的简单示例,它可以在每轮结束后将模型每层的激活保存到硬盘(格式为Numpy 数组),这个激活是对验证集的第一个样本计算得到的。

import keras
import numpy as np
class ActivationLogger(keras.callbacks.Callback):
	def set_model(self, model):
		self.model = model#在训练之前由父模型调用,告诉回调函数是哪个模型在调用它
		layer_outputs = [layer.output for layer in model.layers]
		#模型实例,返回每层的激活
		self.activations_model = keras.models.Model(model.input,
													layer_outputs)
	def on_epoch_end(self, epoch, logs=None):
		if self.validation_data is None:
			raise RuntimeError('Requires validation_data.')
		#获取验证数据的第一个输入样本
		validation_sample = self.validation_data[0][0:1]
		activations = self.activations_model.predict(validation_sample)
		f = open('activations_at_epoch_' + str(epoch) + '.npz', 'w')
		np.savez(f, activations)
		f.close()
发布了25 篇原创文章 · 获赞 41 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_32796253/article/details/89190206