tensorflow模型的保存与加载

1.保存:(保存的变量都是停放,tf.Variable()中的变量,变量一定要有名字)

saver = tf.train.Saver()

saver.run(sess,"./model4/line_model.ckpt")

2.查看保存的变量信息:(将保存的信息打印出来)

from tensorflow.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file("./model4/liner_model.ckpt",None,Ture)

3.将文件中的参数加载到模型中:

saver.restore(sess,"./model4/line_model.ckpt")
w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys})
b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys})
print("w1:",w1)
print("b1:",b1)

完整代码奉上:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
#导入数据集
x = tf.placeholder(shape=[None,784],dtype=tf.float32)
y = tf.placeholder(shape=[None,10],dtype=tf.float32)
#为输入输出定义placehloder

w = tf.Variable(tf.truncated_normal(shape=[784,10],mean=0,stddev=0.5),name="W")
b = tf.Variable(tf.zeros([10]),name="b")
#定义权重
y_pred = tf.nn.softmax(tf.matmul(x,w)+b)
#定义模型结构
loss =tf.reduce_mean(-tf.reduce_sum(y*tf.log(y_pred),reduction_indices=[1]))
#定义损失函数
opt = tf.train.GradientDescentOptimizer(0.05).minimize(loss)
#定义优化算法
saver = tf.train.Saver()#保存模型
sess =tf.Session()
sess.run(tf.global_variables_initializer())
for each in range(1000):
    batch_xs,batch_ys = mnist.train.next_batch(100)
    loss1 = sess.run(loss,feed_dict={x:batch_xs,y:batch_ys})
    opt1 = sess.run(opt,feed_dict={x:batch_xs,y:batch_ys})
    # print(loss1)
    # saver.save(sess,"./model4/line_model.ckpt")#将保存的模型放在指定的文件中
# w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys})
# print(w1)
# b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys})
# print("b:",b)
# y1 = sess.run(y_pred,feed_dict={x:batch_xs,y:batch_ys})
# print("y:",y1)
# print(len(y1))
#打印保存的模型参数
# from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# print_tensors_in_checkpoint_file("./model4/line_model.ckpt",None,True)

#将指定文件中的变量加载到模型中
saver.restore(sess,"./model4/line_model.ckpt")
w1 = sess.run(w,feed_dict={x:batch_xs,y:batch_ys})
b1 = sess.run(b,feed_dict={x:batch_xs,y:batch_ys})
print("w1:",w1)
print("b1:",b1)

猜你喜欢

转载自blog.csdn.net/qq_41853536/article/details/83214602