简单的模型保存与恢复 ——tensorflow

目录

 

保存

保存后进行恢复


保存

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Prepare train data
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.33 + 10

# Define the model
X = tf.placeholder("float")
Y = tf.placeholder("float")
w = tf.Variable(0.0, name="weight")
b = tf.Variable(0.0, name="bias")
loss = tf.square(Y - X * w - b)
optimizer = tf.train.GradientDescentOptimizer(0.01)
grads_and_vars = optimizer.compute_gradients(loss)
train_op = optimizer.apply_gradients((grads_and_vars))
# train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
saver = tf.train.Saver()      #定义保存器
# Create session to run
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    epoch = 1
    for i in range(10):
        for (x, y) in zip(train_X, train_Y):
            _, w_value, b_value = sess.run([train_op, w, b], feed_dict={X: x, Y: y})
        print("Epoch: {}, w: {}, b: {}".format(epoch, w_value, b_value))
        epoch += 1
        saver.save(sess,"./shit.tf")#进行保存
# draw
plt.plot(train_X, train_Y, "+")
plt.plot(train_X, train_X.dot(w_value) + b_value)
plt.show()

上面是一个简单的tensorflow的代码,主要是训练w和b来模拟一些点的走向,这里的重点是模型保存。

保存代码:tf.train.Saver()和saver.save(sess,"./shit.tf")

保存后进行恢复

import tensorflow as tf
import numpy as np

w = tf.Variable(0.0, name="weight")
b = tf.Variable(0.0, name="bias")

#restore
model_file = tf.train.latest_checkpoint('./')
saver = tf.train.Saver();
print(model_file)
with tf.Session() as sess:
    saver.restore(sess,model_file)
    print(sess.run(w))
    print(sess.run(b))

 恢复时定义好在保存时定义的变量就可以恢复了,可以只定义你想要的变量

结果截图:

可以看到 ,我们定义的是0,恢复后再输出结果就变了,变成了我们之前保存的数了。

猜你喜欢

转载自blog.csdn.net/hb_688/article/details/81428998
今日推荐