A Checkpoint saves only the model's parameters, not the computation the model performs, so it is generally used to restore previously trained parameters when the model's source code is available. If you need to export the model so that it can run without the source code, use SavedModel instead.
tf.train.Checkpoint: Save and restore variables
First, declare a Checkpoint:
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
When training is complete and the model needs to be saved, use:
checkpoint.save(save_path_with_prefix)
To reload the previously saved parameters into a model elsewhere, instantiate a Checkpoint again with the same key name, then call its restore method:
model_to_be_restored = MyModel() # the same model, whose parameters are to be restored
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored) # keep the key name as "myAwesomeModel"
checkpoint.restore(save_path_with_prefix_and_index)
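The save and restore steps above can be sketched end to end as follows. This is a minimal illustration, not the tutorial's own model: MyModel here is a hypothetical tf.Module stand-in with two variables, and the save directory is an assumed temporary folder. It relies on checkpoint.save returning the full path, including the index suffix, which can then be passed to restore.

```python
import os
import tempfile

import tensorflow as tf


class MyModel(tf.Module):
    """Hypothetical stand-in for the tutorial's model: one linear layer."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.random.normal([3, 2]), name="w")
        self.b = tf.Variable(tf.zeros([2]), name="b")

    def __call__(self, x):
        return tf.matmul(x, self.w) + self.b


# Assumed scratch location; any writable directory would do.
save_path_with_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save: the keyword "myAwesomeModel" is the key stored in the checkpoint.
model = MyModel()
checkpoint = tf.train.Checkpoint(myAwesomeModel=model)
saved_path = checkpoint.save(save_path_with_prefix)  # prefix plus "-<index>"

# Restore into a freshly initialized instance, using the same key name.
model_to_be_restored = MyModel()
restored_checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)
restored_checkpoint.restore(saved_path)

# The restored variables should now equal the saved ones.
weights_match = bool(
    tf.reduce_all(tf.equal(model.w, model_to_be_restored.w)).numpy()
)
print(saved_path, weights_match)
```

If the key name differed between the two Checkpoint objects, the variables of model_to_be_restored would simply keep their random initial values, which is why the key must stay consistent.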
tf.train.CheckpointManager: Delete old Checkpoints and customize file numbering
During training, we often save a numbered Checkpoint every fixed number of steps. However, we frequently have needs such as the following:
- After a long period of training, the program will have saved a large number of Checkpoints, but we only want to keep the last few;
- Checkpoint numbering starts from 1 by default and increases by 1 each time, but we may want a different numbering scheme (for example, using the current batch number as the file number).
tf.train.CheckpointManager can meet both needs. After defining a Checkpoint, define a CheckpointManager:
checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)
Here, directory is the path where the files are saved, checkpoint_name is the file-name prefix (ckpt by default if not provided), and max_to_keep is the number of Checkpoints to retain.
When we need to save the model, we can call manager.save() directly. To specify the number of the saved Checkpoint ourselves, pass the checkpoint_number argument when saving, for example manager.save(checkpoint_number=100).
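As a rough sketch of how these parameters interact, the example below saves every 100 batches, numbers each file by its batch index, and keeps only the three most recent Checkpoints. The step variable is a hypothetical stand-in for real model state, the loop body is a placeholder for actual training, and the directory is an assumed temporary folder.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical stand-in for real model state: one counter variable.
step = tf.Variable(0, dtype=tf.int64)
checkpoint = tf.train.Checkpoint(step=step)
manager = tf.train.CheckpointManager(
    checkpoint,
    directory=tempfile.mkdtemp(),  # assumed scratch directory
    checkpoint_name="model.ckpt",
    max_to_keep=3,
)

for batch_index in range(1, 501):
    step.assign(batch_index)  # placeholder for one real training step
    if batch_index % 100 == 0:
        # Number the file by batch index instead of the default 1, 2, 3, ...
        manager.save(checkpoint_number=batch_index)

# Only the newest max_to_keep checkpoints remain, listed oldest to newest.
kept = [os.path.basename(path) for path in manager.checkpoints]
print(kept)

# Reset the variable, then restore it from the most recent checkpoint.
step.assign(0)
checkpoint.restore(manager.latest_checkpoint)
print(int(step.numpy()))
```

Here manager.checkpoints and manager.latest_checkpoint are used to inspect what survived the cleanup; the files numbered 100 and 200 are deleted automatically once the fourth and fifth saves occur.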