模型的保存和恢复

模型的保存和恢复

1. 模型整体保存

模型整体保存指保存模型权重,模型配置和优化器配置

import tensorflow as tf
# 保存模型
model.save("model.h5")
# 加载模型
new_model = tf.keras.models.load_model("model.h5")

2. 仅保存模型架构

json_config = model.to_json()
# 保存json文件
with open("model_config.json", 'wt') as f:
    f.write(json_config)
# 读取json文件
with open("model_config.json", 'rt') as f:
    json_config = f.read(json_config)
new_model = tf.keras.models.model_from_json(json_config)

3. 仅保存模型权重

# 保存权重
model.save_weights("model_weights.h5")
# 恢复权重
model.load_weights("model_weights.h5")

4. 在训练过程中保存检查点

ModelCheckpoint(filepath, monitor=‘val_loss’, verbose=0, save_best_only=False, save_weights_only=False, mode=‘auto’, period=1)

  • filepath:字符串,保存模型的路径
  • monitor:需要监视的值
  • verbose:信息展示模式,0或1
  • save_best_only:当设置为True时,监测值有改进时才会保存当前的模型
  • mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则。
    例如,当监测值为val_acc时,模式应为max,当监测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。
  • save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)
  • period:CheckPoint之间的间隔的epoch数
path = 'best_model.h5'
checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath=path, 
                                                  monitor='loss', 
                                                  verbose=1, 
                                                  save_best_only=True)

model.compile(optimizer = 'adam',
              loss = 'sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_dataset, epochs=10, callbacks=[checkpointer])

5. 在自定义训练中保存模型

# 存储的文件夹路径
cp_dir = "train_cp"
# 要保存的检查点的前缀
cp_prefix = os.path.join(cp_dir, "ckpt")

# Checkpoint的参数为要保存的对象
check_point = tf.train.Checkpoint(optimizer = optimizer,
                                  model = model)
def train(model, train_dataset, test_dataset, epochs):
    # 恢复最新的检查点
    # check_point.restore(tf.train.latest_checkpoint(cp_dir))
    for epoch in range(epochs):
        # 训练
        # 测试
        # 重置评估指标
        # 保存检查点
        check_point.save(file_prefix = cp_prefix)
发布了14 篇原创文章 · 获赞 1 · 访问量 573

猜你喜欢

转载自blog.csdn.net/lan_faster/article/details/103612683