tf.train.Saver


class tf.train.Saver

保存和恢复变量

最简单的保存和恢复模型的方法是使用tf.train.Saver 对象。构造器给graph 的所有变量,或是定义在列表里的变量,添加save 和 restore ops。saver 对象提供了方法来运行这些ops,定义检查点文件的读写路径。

检查点是专门格式的二进制文件,将变量name 映射到 tensor value。检查checkpoin 内容最好的方法是使用Saver 加载它。

Savers 可以使用提供的计数器自动计数checkpoint 文件名。这可以是你在训练一个模型时,在不同的步骤维持多个checkpoint。例如你可以使用 training step number 计数checkpoint 文件名。为了避免填满硬盘,savers 自动管理checkpoint 文件。例如,你可以最多维持N个最近的文件,或者没训练N小时保存一个checkpoint.

通过传递一个值给可选参数 global_step ,你可以编号checkpoint 名字。

[python] view plain copy
print ?
  1. saver.save(sess, 'my-model', global_step=0) ==>filename: 'my-model-0'  
  2. saver.save(sess, 'my-model', global_step=1000) ==>filename: 'my-model-1000'  
saver.save(sess, 'my-model', global_step=0) ==>filename: 'my-model-0'
saver.save(sess, 'my-model', global_step=1000) ==>filename: 'my-model-1000'
另外,Saver() 构造器可选的参数可以让你控制硬盘上 checkpoint 文件的数量。

  • max_to_keep:  表明保存的最大checkpoint 文件数。当一个新文件创建的时候,旧文件就会被删掉。如果值为None或0,表示保存所有的checkpoint 文件。默认值为5(也就是说,保存最近的5个checkpoint 文件)。
  • keep_checkpoint_every_n_hour:  除了保存最近的max_to_keep checkpoint 文件,你还可能想每训练N小时保存一个checkpoint 文件。这将是非常有用的,如果你想分析一个模型在很长的一段训练时间内是怎么改变的。例如,设置 keep_checkpoint_every_n_hour=2 确保没训练2个小时保存一个checkpoint 文件。默认值10000小时无法看到特征。
注意,你仍然必须调用save() 方法去保存模型。传递这些参数给构造器并不会自动为你保存这些变量。

一个定期保存的训练程序如下这样:

[python] view plain copy
print ?
  1. #Create a saver  
  2. saver=tf.train.Saver(...variables...)  
  3. #Launch the graph and train, saving the model every 1,000 steps.  
  4. sess=tf.Session()  
  5. for step in xrange(1000000):  
  6.     sess.run(...training_op...)  
  7.     if step % 1000 ==0:  
  8.         #Append the step number to the checkpoint name:  
  9.         saver.save(sess,'my-model',global_step=step)  
#Create a saver
saver=tf.train.Saver(...variables...)
#Launch the graph and train, saving the model every 1,000 steps.
sess=tf.Session()
for step in xrange(1000000):
    sess.run(...training_op...)
    if step % 1000 ==0:
        #Append the step number to the checkpoint name:
        saver.save(sess,'my-model',global_step=step)
除了checkpoint 文件之外,savers 还在硬盘上保存了一个协议缓存,存储最近的checkpoint 列表。这用于管理 被编号的checkpoint 文件,并且通过latest_checkpoint() 可以很容易找到最近的checkpoint 的路径。协议缓存存储在紧挨checkpoint 文件的名为 'checkpoint' 的文件中。

如果你创建了几个savers,你可以调用save() 指定协议缓存的文件名。

tf.train.Saver.__init__(var_list=None, reshape=False, shared=False, max_to_keep=5, keep_checkpoint_every_n_hour=10000.0, name=None, restore_sequentially=False, saver_def=None, builder=None)

创建一个Saver

构造器添加操作去保存和恢复变量。

var_list 指定了将要保存和恢复的变量。它可以传dict 或者list

  • 变量名字的dict: key 是将用来在checkpoint 文件中存储和恢复的变量的名称。
  • 变量的list:  变量的 op name
例如:
[python] view plain copy
print ?
  1. v1=tf.Variable(..., name='v1')  
  2. v2=tf.Variable(..., name='v2')  
  3.   
  4. # Pass the variables as a dict:  
  5. saver=tf.train..Saver({'v1':v1, 'v2':v2})  
  6. # Or pass them as a list  
  7. saver=tf.train..Saver([v1,v2])  
  8. # Passing a list is equivalent to passing a dict with the variable op names as keys:  
  9. saver=tf.train..Saver({v.op.name: v for v in [v1,v2]})  
v1=tf.Variable(..., name='v1')
v2=tf.Variable(..., name='v2')

# Pass the variables as a dict:
saver=tf.train..Saver({'v1':v1, 'v2':v2})
# Or pass them as a list
saver=tf.train..Saver([v1,v2])
# Passing a list is equivalent to passing a dict with the variable op names as keys:
saver=tf.train..Saver({v.op.name: v for v in [v1,v2]})
可选参数 reshape ,如果为True,允许从保存文件中恢复一个不同shape 的变量,但元素的数量和type一致。如果你reshap 了一个变量而又想从一个旧的文件中恢复,这是非常有用的。
可选参数 shared,如果为True,通知每个设备上共享的checkpoint.

tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True)
保存变量
这个方法运行通过构造器添加的操作。它需要启动图的session。被保存的变量必须经过了初始化。
方法返回新建的checkpoint 文件的路径。路径可以直接传给restore() 进行调用。
参数:
  • sess:  用于保存变量的Session
  • save_path:  checkpoint 文件的路径。如果saver 是共享的,这是共享checkpoint 文件名的前缀。
  • global_step:  如果提供了global step number,将会追加到 save_path 后面去创建checkpoint 的文件名。可选参数可以是一个Tensor,一个name Tensor或integer Tensor.
返回值:
一个字符串:保存变量的路径。如果saver 是被共享的,字符串以'-?????-of-nnnnn' 结尾。'nnnnn' 是共享的数目。
保存变量
用tf.train.Saver() 创建一个Saver 来管理模型中的所有变量。
[python] view plain copy
print ?
  1. #!/usr/bin/env python  
  2. # coding=utf-8  
  3.   
  4. import os  
  5. import tensorflow as tf  
  6.   
  7. # Create some variables.  
  8. v1=tf.Variable([[1,1],[2,2],[3,3]],name="v1")  
  9. v2=tf.Variable([[4,4],[5,5],[6,7]],name="v2")  
  10. # Add an op to initialize the variables.  
  11. init_op=tf.initialize_all_variables()  
  12. # Add ops to save and restore all the variables.  
  13. saver=tf.train.Saver()  
  14. # Later, launch the model, initialize the variables, do some work, save the variables to disk.  
  15. with tf.Session() as sess:  
  16.     sess.run(init_op)  
  17.     # Do some work with the model.  
  18.     save_path=saver.save(sess,"/home/yhk/tmp/test/model.ckpt")  
  19.     print "Model saved in file: ", save_path  
#!/usr/bin/env python
# coding=utf-8

import os
import tensorflow as tf

# Create some variables.
v1=tf.Variable([[1,1],[2,2],[3,3]],name="v1")
v2=tf.Variable([[4,4],[5,5],[6,7]],name="v2")
# Add an op to initialize the variables.
init_op=tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver=tf.train.Saver()
# Later, launch the model, initialize the variables, do some work, save the variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    # Do some work with the model.
    save_path=saver.save(sess,"/home/yhk/tmp/test/model.ckpt")
    print "Model saved in file: ", save_path



如果你不给tf.train.Saver() 传入任何参数,那么server 将处理graph 中的所有变量。其中每一个变量都以变量创建时传入的名称被保存。

tf.train.Saver.restore(sess, save_path)
恢复之前保存的变量
这个方法运行构造器为恢复变量所添加的操作。它需要启动图的Session。恢复的变量不需要经过初始化,恢复作为初始化的一种方法。
save_path 参数是之前调用save() 的返回值,或调用 latest_checkpoint() 的返回值。
参数:
  • sess:  用于恢复参数的Session
  • save_path:  参数之前保存的路径

猜你喜欢

转载自blog.csdn.net/weixin_42052460/article/details/80715586
今日推荐