Tensorflow之模型参数的Saver保存读取

一、Saver保存

import tensorflow as tf
import numpy as np

#定义W和b
W = tf.Variable([[1,2,3],[3,5,6]],dtype = tf.float32,name = 'weight')
b = tf.Variable([1,2,3],dtype = tf.float32,name = 'biases')
#注:初始化变量Variable
init = tf.global_variables_initializer()


#建立tf.train.Saver() 来保存, 提取变量。
#建立my_net文件夹,保存变量
saver =  tf.train.Saver()

sess = tf.Session()
sess.run(init)
#保存变量到路径my_net
save_path = saver.save(sess,"my_net/save_net.ckpt")#保存格式为ckpt

#输出保存的变量
print("save path:",save_path)

结果: 

二、Saver读取

import tensorflow as tf
import numpy as np


#建立W,b的空容器
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

#不需要初始化变量

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

猜你喜欢

转载自blog.csdn.net/qq_33373858/article/details/83684669