tensorflow 之 模型的保存(save)、恢复/加载(restore)

1、什么是 tensorflow 模型

当你训练完一个神经网络,你可能会想要保存这个网络,以便将来拿来使用或直接用于其他数据的 deploy,

tensorflow 模型包括:已训练并优化的权重参数,网络结构和 graph。

tensorflow 模型文件包括两大块:

  • meta graph :序列化缓冲文件,保存完整的网络结构,graph ,即 all variables, operations, collections 等,扩展名是 .meta
  • checkpoint file:二进制文件,包括 weights, biases, gradients 和 all the other variables,扩展名为 .ckpt 。但是从0.11版本开始,就不是单独的 .ckpt 文件了,而是有两个文件:
>>mymodel.data-00000-of-00001 #包括训练变量,可从这个文件开始继续训练
>>mymodel.index 

此外,checkpoint 保存最近一次的模型。所以 tensorflow 共包含以下四个文件

2、保存 tensorflow 模型

有时候不知道哪个模型是最优的,故需要保存多个模型。默认情况下保存最近的5个模型。

tensorflow 中的变量只在会话 session 中存在,所以需要在 saver 对象上调用 save 方法,将模型保存在会话中。

#模型的保存
import tensorflow as tf
import os

w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver() #可指定需要存储的tensor,不指定则全部保存
with tf.Session() as sess:    
    #sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    #创建保存模型的文件夹
    if not os.path.exists('my_model'):
        os.mkdir('./my_model')
        saver.save(sess, './my_model/my_test_model')

#可通过设置saver.save()的参数指定保存哪一步的模型
saver.save(sess, './my_model/my_test_model', global_step=1000) #保存1000步的模型

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

1000步的模型,会在 my_test_model 后 append ‘-1000’ 

.meta 保存的是网络结构,训练过程中不改变网络结果,保存一次即可,可使用如下语句:

saver.save(sess, './my_model/my_test_model', global_step=step, write_meta_graph=False)

如果想要每2小时保存一次模型,且保存最近的4个模型,可使用如下语句:

#saves a model every 2 hours and maximum 4 latest models are saved.
saver = tf.train.Saver(max_to_keep=4, keep_checkpoint_every_n_hours=2)

如果不保存全部的 tensor ,可通过指定 variables/collections 来保存,使用如下语句:

#将需要保存的变量以列表形式添加在saver中?自己的理解~确实是这个语句
saver = tf.train.Saver([w1, w2])

3、加载预训练模型

如果需要用别人训练好的模型做微调,需要以下两步:

  • 使用如下语句加载网络结构:
saver = tf.train.import_meta_graph('./my_model/my_test_model.meta')
  • 使用如下语句加载参数:
import tensorflow as tf
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('./my_model/my_test_model.meta') #加载网络结构
    new_saver.restore(sess, tf.train.latest_checkpoint('./my_model')) #加载最近一次保存的ckpt
    #初始化参数
    sess.run(tf.global_variables_initializer())
    print(sess.run('w1:0'))
    #返回:INFO:tensorflow:Restoring parameters from ./my_model\my_test_model
      [ 0.35064858  2.87996149]

参考:https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

https://blog.csdn.net/liuxiao214/article/details/79048136

猜你喜欢

转载自blog.csdn.net/weixin_42338058/article/details/84310969