tensorflow学习之Saver保存读取

  目前不是很懂。。但主要意思是tf可以把一开始定义的参数,包括Weights和Biases保存到本地,然后再定义一个变量框架去加载(restore)这个参数,作为变量本身的参数进行后续的训练,具体如下:

  

import numpy as np
#Save to file
 W = tf.Variable([[1,2,3],[3,4,5]],dtype=tf.float32,name='weights')
 b = tf.Variable([[1,2,3]],dtype=tf.float32,name='biases')

 init= tf.global_variables_initializer()

 saver = tf.train.Saver()

 with tf.Session() as sess:
     sess.run(init)
     save_path = saver.save(sess,"my_net/save_net.ckpt")
     print("Save to path:", save_path)

和代码同一目录下就出现了my_net这个文件夹,同时里面有了四个文件

然后,开始restore该参数

# restore variables
#redefine the same shape and same type for your variables
tf.reset_default_graph()
W = tf.Variable(np.arange(6).reshape((2,3)),dtype=tf.float32,name="weights")
b = tf.Variable(np.arange(3).reshape((1,3)),dtype=tf.float32,name="biases") 

#not need init step

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess,"my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))


#
INFO:tensorflow:Restoring parameters from my_net/save_net.ckpt
weights: [[1. 2. 3.]
 [3. 4. 5.]]
biases: [[1. 2. 3.]]

可以看到把原来的weights和biases都加载了

猜你喜欢

转载自www.cnblogs.com/yqpy/p/11042034.html