Checkpoint 只保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的时候恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),可以使用SavedModel。
tf.train.Checkpoint
:变量的保存与恢复
首先声明一个 Checkpoint
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
当模型训练完成需要保存的时候,使用
checkpoint.save(save_path_with_prefix)
当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法
model_to_be_restored = MyModel() # 待恢复参数的同一模型
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # 键名保持为“myAwesomeModel”
checkpoint.restore(save_path_with_prefix_and_index)
tf.train.CheckpointManager
删除旧的 Checkpoint 以及自定义文件编号
在模型的训练过程中,我们往往每隔一定步数保存一个 Checkpoint 并进行编号。不过很多时候我们会有这样的需求:
-
在长时间的训练后,程序会保存大量的 Checkpoint,但我们只想保留最后的几个 Checkpoint;
-
Checkpoint 默认从 1 开始编号,每次累加 1,但我们可能希望使用别的编号方式(例如使用当前 Batch 的编号作为文件编号)。
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)
directory
参数为文件保存的路径, checkpoint_name
为文件名前缀(不提供则默认为 ckpt
), max_to_keep
为保留的 Checkpoint 数目。
在需要保存模型的时候,我们直接使用 manager.save()
即可。如果我们希望自行指定保存的 Checkpoint 的编号,则可以在保存时加入 checkpoint_number
参数