模型的保存和恢复
1. 模型整体保存
模型整体保存指保存模型权重,模型配置和优化器配置
import tensorflow as tf
# 保存模型
model.save("model.h5")
# 加载模型
new_model = tf.keras.models.load_model("model.h5")
2. 仅保存模型架构
json_config = model.to_json()
# 保存json文件
with open("model_config.json", 'wt') as f:
f.write(json_config)
# 读取json文件
with open("model_config.json", 'rt') as f:
json_config = f.read(json_config)
new_model = tf.keras.models.model_from_json(json_config)
3. 仅保存模型权重
# 保存权重
model.save_weights("model_weights.h5")
# 恢复权重
model.load_weights("model_weights.h5")
4. 在训练过程中保存检查点
ModelCheckpoint(filepath, monitor=‘val_loss’, verbose=0, save_best_only=False, save_weights_only=False, mode=‘auto’, period=1)
- filepath:字符串,保存模型的路径
- monitor:需要监视的值
- verbose:信息展示模式,0或1
- save_best_only:当设置为True时,监测值有改进时才会保存当前的模型
- mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则。
例如,当监测值为val_acc时,模式应为max,当监测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。 - save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
- period:CheckPoint之间的间隔的epoch数
path = 'best_model.h5'
checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath=path,
monitor='loss',
verbose=1,
save_best_only=True)
model.compile(optimizer = 'adam',
loss = 'sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(train_dataset, epochs=10, callbacks=[checkpointer])
5. 在自定义训练中保存模型
# 存储的文件夹路径
cp_dir = "train_cp"
# 要保存的检查点的前缀
cp_prefix = os.path.join(cp_dir, "ckpt")
# Checkpoint的参数为要保存的对象
check_point = tf.train.Checkpoint(optimizer = optimizer,
model = model)
def train(model, train_dataset, test_dataset, epochs):
# 恢复最新的检查点
# check_point.restore(tf.train.latest_checkpoint(cp_dir))
for epoch in range(epochs):
# 训练
# 测试
# 重置评估指标
# 保存检查点
check_point.save(file_prefix = cp_prefix)