TensorFlow 2 tf.train.Checkpoint: save and restore variables

Checkpoint saves only the model's parameters, not the computation that produces them, so it is generally used to restore previously trained parameters when the model's source code is available. If you need to export the model so it can run without the source code, use SavedModel instead.
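For comparison, exporting a SavedModel that can run without the model's source code is a single call. A minimal sketch (the model and export path here are illustrative):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])  # stand-in for your trained model
model.build(input_shape=(None, 5))                        # create the variables before exporting
tf.saved_model.save(model, './saved/1')                   # exports graph + weights to disk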

tf.train.Checkpoint: save and restore variables

First, declare a Checkpoint:

checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)

When training is complete and the model needs to be saved, call:

checkpoint.save(save_path_with_prefix)
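A minimal end-to-end save sketch (the model, optimizer, and ./save path are illustrative; the key names are arbitrary but must match at restore time):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])  # stand-in for your model
model.build(input_shape=(None, 5))                        # make sure variables exist
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(myAwesomeModel=model, myAwesomeOptimizer=optimizer)
# ... training loop ...
save_path = checkpoint.save('./save/model.ckpt')          # returns e.g. './save/model.ckpt-1'

Each call writes files named with the prefix plus an auto-incremented index, e.g. model.ckpt-1.index and model.ckpt-1.data-00000-of-00001.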

When you need to reload the previously saved parameters elsewhere, instantiate a Checkpoint again, keeping the key names consistent, and then call its restore method:

model_to_be_restored = MyModel()                                        # the same model whose parameters are to be restored
checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # the key name must stay "myAwesomeModel"
checkpoint.restore(save_path_with_prefix_and_index)
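Here save_path_with_prefix_and_index is the prefix plus the index of a specific save, e.g. './save/model.ckpt-1'. To restore whichever checkpoint in a directory is most recent, tf.train.latest_checkpoint looks up its path for you (the ./save directory is illustrative):

checkpoint.restore(tf.train.latest_checkpoint('./save'))  # resolves to e.g. './save/model.ckpt-1'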

tf.train.CheckpointManager: delete old Checkpoints and customize file numbering

During training, we often save a numbered Checkpoint every fixed number of steps. But two needs come up frequently:

  • After a long training run, the program accumulates a large number of Checkpoints, but we only want to keep the last few;

  • Checkpoints are numbered from 1 by default, incrementing by 1 on each save, but we may want a different numbering scheme (for example, using the index of the current batch as the file number).

tf.train.CheckpointManager addresses both:

checkpoint = tf.train.Checkpoint(model=model)
manager = tf.train.CheckpointManager(checkpoint, directory='./save', checkpoint_name='model.ckpt', max_to_keep=k)

The directory parameter is the path where checkpoint files are saved, checkpoint_name is the prefix of the file names (ckpt by default), and max_to_keep is the number of recent checkpoints to retain.

When we need to save the model, we can simply call manager.save(). If we want to choose the number of the saved Checkpoint ourselves, we can pass the checkpoint_number argument when saving:
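A sketch of numbered saves inside a training loop, assuming the checkpoint and manager defined above (num_batches and the save interval are illustrative):

num_batches = 1000
for batch_index in range(1, num_batches + 1):
    # ... one training step ...
    if batch_index % 100 == 0:
        path = manager.save(checkpoint_number=batch_index)  # e.g. './save/model.ckpt-100'

With max_to_keep=k, the manager automatically deletes older checkpoints, keeping only the k most recent.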


Origin blog.csdn.net/qq_40107571/article/details/131367800