模型变量保存与恢复tensorflow.train.Checkpoint

参考:tf公众号,【tf2.0】

注:Checkpoint只用于保存模型的参数,不保存模型的计算过程,因此一般用于在具有模型源代码的情况下恢复之前训练好的模型参数。如果需要导出模型(无需源代码也能运行模型),参考 SaveModel

个人经验:当前1.14版本是【tf.train.Saver】。一般保存非模型参数都是用python内置的pickle模块,但是若你想用其保存 model.variables,是无法完成的。之前在用pytorch的时候保存时都要进行序列化,tf也是一样,但是pickle不能将其序列化

类方法:save,restore
可以保存如下对象:
tf.keras.optimizer
tf.keras.Layer
tf.keras.Model
tf.Variable

首先声明一个Checkpoint

checkpoint = tf.train.Checkpoint(model=model)

这个类的初始化参数比较特殊,是一个**kwargs,一系列的键值对,键名可以随便取,键值则是需要保存到对象
例子: 保存tf.keras.Model的实例模型model和tf.optimizer的优化器optimizer

checkpoint = tf.train.Checkpoint(Mymodel=model, Myoptimizer=optimizer)

这里的Mymodel是我们随意的键名取值,在恢复模型参数的时候还要使用这一键名

保存参数

save_path_with_perfix:保存文件地址 + 前缀

checkpoint.save(save_path_with_perfix)

例子:
在源代码目录建立一个名为 save 的文件夹并调用一次 checkpoint.save(’./save/model.ckpt’) ,我们就可以在可以在 save 目录下发现名为 checkpoint 、 model.ckpt-1.index 、 model.ckpt-1.data-00000-of-00001 的三个文件,这些文件就记录了变量信息。checkpoint.save()
方法可以运行多次,每运行一次都会得到一个.index 文件和.data 文件,序号依次累加。(重要)

载入参数

当在其他地方需要为模型重新载入之前保存的参数时,需要再次实例化一个 checkpoint,同时保持键名的一致。再调用 checkpoint 的 restore 方法。就像下面这样:

1model_to_be_restored = MyModel()                                        # 待恢复参数的同一模型
2checkpoint = tf.train.Checkpoint(myAwesomeModel=model_to_be_restored)   # 键名保持为“myAwesomeModel”
3checkpoint.restore(save_path_with_prefix_and_index)

即可恢复模型变量。 save_path_with_prefix_and_index 是之前保存的文件的目录 + 前缀 + 编号。例如,调用 checkpoint.restore(’./save/model.ckpt-1’) 就可以载入前缀为 model.ckpt ,序号为 1 的文件来恢复模型。

当保存了多个文件时,我们往往想载入最近的一个。可以使用 tf.train.latest_checkpoint(save_path) 这个辅助函数返回目录下最近一次 checkpoint 的文件名。例如如果 save 目录下有 model.ckpt-1.indexmodel.ckpt-10.index 的 10 个保存文件, tf.train.latest_checkpoint(’./save’) 即返回 ./save/model.ckpt-10

总体而言,恢复与保存变量的典型代码框架如下:

1# train.py 模型训练阶段
2
3model = MyModel()
4# 实例化Checkpoint,指定保存对象为model(如果需要保存Optimizer的参数也可加入)
5checkpoint = tf.train.Checkpoint(myModel=model)
6# ...(模型训练代码)
7# 模型训练完毕后将参数保存到文件(也可以在模型训练过程中每隔一段时间就保存一次)
8checkpoint.save('./save/model.ckpt')
1# test.py 模型使用阶段
2
3model = MyModel()
4checkpoint = tf.train.Checkpoint(myModel=model)             # 实例化Checkpoint,指定恢复对象为model
5checkpoint.restore(tf.train.latest_checkpoint('./save'))    # 从文件恢复模型参数
6# 模型使用代码

注解

tf.train.Checkpoint 与以前版本常用的 tf.train.Saver 相比,强大之处在于其支持在 Eager Execution 下 “延迟” 恢复变量。具体而言,当调用了 checkpoint.restore() ,但模型中的变量还没有被建立的时候,Checkpoint 可以等到变量被建立的时候再进行数值的恢复。Eager Execution 下,模型中各个层的初始化和变量的建立是在模型第一次被调用的时候才进行的(好处在于可以根据输入的张量形状而自动确定变量形状,无需手动指定)。这意味着当模型刚刚被实例化的时候,其实里面还一个变量都没有,这时候使用以往的方式去恢复变量数值是一定会报错的。比如,你可以试试在 train.py 调用 tf.keras.Model 的 save_weight() 方法保存 model 的参数,并在 test.py 中实例化 model 后立即调用 load_weight() 方法,就会出错,只有当调用了一遍 model 之后再运行 load_weight() 方法才能得到正确的结果。可见, tf.train.Checkpoint 在这种情况下可以给我们带来相当大的便利。另外, tf.train.Checkpoint 同时也支持 Graph Execution 模式。

最后bb一句,整个工程的参数最好还是使用oython自带的argparse,docopt让人头大。

发布了48 篇原创文章 · 获赞 9 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/NewDreamstyle/article/details/102690785
今日推荐