【tensorflow】模型存储和恢复

  tensorflow中,模型的存储和恢复使用tf.train.Saver类,模型存储使用该类的 save 方法。模型恢复使用restore 方法。

模型存储

  模型存储使用tf.train.Saver.save()方法。以saver.save(sess, 'model/model.ckpt')为例,在model路径下会有四个文件(如下图)

  • checkpoint 记录保存信息,通过它可以定位最新保存的模型;
  • *.meta 保存当前图结构;
  • *.index 保存当前参数名;
  • *.data 保存当前参数名。
import tensorflow as tf

a = tf.get_variable('a', shape=[3], initializer=tf.constant_initializer(1))
b = tf.get_variable('b', shape=[5], initializer=tf.constant_initializer(2))
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.save(sess, 'model/model.ckpt')
模型恢复

  模型恢复使用tf.train.Saver.restore() 方法。

```
import tensorflow as tf

a = tf.get_variable('a', shape=[3])
b = tf.get_variable('b', shape=[5])

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'model/model.ckpt')
    print(sess.run(a)) # [ 1.  1.  1.]

    print(sess.run(b)) # [ 2.  2.  2.  2.  2.]


```
保存和恢复部分变量

  使用 save 方法存储模型时,若不指定参数,则 Saver 会处理图中所有的变量。每个变量都保存在创建变量时所传递的名称下。我们还可以对指定变量进行存储和恢复。示例如下:
- save

import tensorflow as tf

a = tf.get_variable('a', shape=[3], initializer=tf.constant_initializer(1))
b = tf.get_variable('b', shape=[5], initializer=tf.constant_initializer(2))
saver = tf.train.Saver({"a": a})

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    saver.save(sess, 'model/model.ckpt')
  • restore

import tensorflow as tf

a = tf.get_variable('a', shape=[3])
# b = tf.get_variable('b', shape=[5])

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, 'model/model.ckpt')
    print(sess.run(a))
    # print(sess.run(b))

检查某个检查点的变量

  使用 inspect_checkpoint 库快速检查某个检查点的变量。

from tensorflow.python.tools import inspect_checkpoint as chkp

# 打印所有 tensors
chkp.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name='', all_tensors=True)

# tensor_name:  a
# [ 1.  1.  1.]
# tensor_name:  b
# [ 2.  2.  2.  2.  2.]

# 打印 tensor a
chkp.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name='a', all_tensors=False)

# tensor_name:  a
# [ 1.  1.  1.]

# 打印 tensor b
chkp.print_tensors_in_checkpoint_file("model/model.ckpt", tensor_name='b', all_tensors=False)
# tensor_name:  b
# [ 2.  2.  2.  2.  2.]
参考文件

猜你喜欢

转载自blog.csdn.net/lionel_fengj/article/details/80494367