tensorflow学习日记03

变量保存与导入

tensorflow内置的参数导出和导入基本用法,用于保存训练好的模型参数

import tensorflow as tf
"""
变量声明,运算声明 例:w = tf.get_variable(name="vari_name", shape=[], dtype=tf.float32)
初始化op声明
"""
#创建saver对象,它添加了一些op用来save和restore模型参数
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    #训练模型。。。
    #使用saver提供的简便方法去调用 save op
    saver.save(sess, "save_path/file_name.ckpt") #file_name.ckpt如果不存在的话,会自动创建
#后缀可加可不加

现在,训练好的模型参数已经存储好了,我们来看一下怎么调用训练好的参数。变量保存的时候,保存的是变量名:value,键值对。restore的时候,也是根据key-value来进行的

import tensorflow as tf
"""
变量声明,运算声明
初始化op声明
"""
#创建saver 对象
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)#在这里,可以执行这个语句,也可以不执行,即使执行了,初始化的值也会被restore的值给override
    saver.restore(sess, "save_path/file_name.ckpt-???") 
    #会将已经保存的变量值resotre到变量中,自己看好要restore哪步的 

如何restore变量的子集,然后使用初始化op初始化其他变量

#想要实现这个功能的话,必须从Saver的构造函数下手
saver=tf.train.Saver([sub_set])
init = tf.initialize_all_variables()
with tf.Session() as sess:
  #这样你就可以使用restore的变量替换掉初始化的变量的值,而其它初始化的值不受影响
  sess.run(init)
  if restor_from_checkpoint:
      saver.restore(sess,"file.ckpt")
  # train
  saver.save(sess,"file.ckpt")

猜你喜欢

转载自blog.csdn.net/wyisfish/article/details/80711367
今日推荐