tensorflow 的saver

英文介绍:https://cv-tricks.com/tensorflow-tutorial/save-restore-tensorflow-models-quick-complete-tutorial/

每个模型存储后都有三个文件,分别是 .meta ,.data-00000-of-00001和.index

  • .meta存储模型的结构,变量等
  • .data-00000-of-00001和.index统称为Checkpoint file,存储模型经过训练后权重的值

存模型

import tensorflow as tf
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my_test_model')

# This will save following files in Tensorflow v >= 0.11
# my_test_model.data-00000-of-00001
# my_test_model.index
# my_test_model.meta
# checkpoint

重点是这句saver.save(sess, 'my_test_model'),会生成一个系列名为my_test_model的模型存储文件。

恢复的时候需要:

sess = tf.Session()
saver = tf.train.import_meta_graph('my_test_model.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))

那个latest_checkpoint是在所给的文件夹中找相应的checkpoint文件,不用把文件的全部路径给出,只给相应的文件夹即可

然后是恢复权重:

w1 = sess.run('w1:0')
w2 = sess.run('w2:0')

a1 = tf.matmul(x_data,w1)
a2 = tf.matmul(a1,w2)

y_output = sess.run(a2)

就可以了

转载于:https://www.jianshu.com/p/aa89aac40d03

猜你喜欢

转载自blog.csdn.net/weixin_33797791/article/details/91070781