Tensorflow变量的保存和恢复

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/hello_java_Android/article/details/84342410

Tensorflow使用tf.train.Saver来进行变量的保存和恢复,当然还可以进行整个模型的保存,以下代码实例展示了如何进行模型和变量的保存

#-*- coding=utf-8 -*-
import tensorflow as tf

# sess = tf.InteractiveSession()

weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name='weights')
bias = tf.Variable(tf.zeros([200], name='bias'))
w2 = tf.Variable(weights.initialized_value, name="w2")
w3 = tf.Variable(weights.initialized_value() * 0.2, name="w3")

test = tf.Variable(300, name='test')
# init_op = tf.global_variables_initializer()
init_op = tf.initialize_all_variables()
# sess.run(init_op)
#add an op to save and restore all the variables
saver = tf.train.Saver({'weights':w3,'test':test})

# w2 = tf.Variable(weights.initialized_value, name="w2")
with tf.Session() as sess:
    sess.run(init_op)
    print('weights is: \n', weights.eval())
    print('w2 is: \n', w2.eval())
    print('w3 is: \n', w3.eval())
    save_path = saver.save(sess, './tmp/saver/model.ckpt')
    writer = tf.summary.FileWriter('./tmp/log_dir', sess.graph)
    print("vairables saved in :", save_path)

# with tf.Session() as sess:
#     saver.restore(sess, './tmp/saver/model.ckpt')
#     print('variables restored')
# print('relued weights is: \n', tf.nn.relu(weights).eval())

tf.train.Saver({'weights':w3,'test':test})这句好就是具体变量的保存了,使用键值对的形式,多个变量用逗号分割,最终通过saver.save(sess,'./tmp/saver/model.ckpt'),这里传入两个参数,一个是sess,会话对象,.另一个是指定保存路径,恢复时要与此路径一致才可以正常恢复,那么模型怎么保存呢?

其实如果不指定具体要保存的变量就是保存整个模型,也就是没有这句话:saver = tf.train.Saver({'weights':w3,'test':test})

看明白了没?

恢复操作也是通过Saver来实现:

#-*- coding=utf8 -*-
import tensorflow as tf

test = tf.Variable(tf.random_normal([784, 200]), name='weights')
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, './tmp/saver/model.ckpt')
    print('test = ', sess.run(test))

要恢复哪个变量呢?  根据保存的时候键值对的key来指定要恢复的变量,但是注意shape要一致,然后创建saver,通过sess和路径来恢复

猜你喜欢

转载自blog.csdn.net/hello_java_Android/article/details/84342410
今日推荐