模型保存
Saver 保存的是 图结构 和 session中的trainable的张量。
流程demo:
saver = tf.train.Saver(max_to_keep=3)
#...
with tf.Session() as sess:
# ...
for step in range(1000000):
# ...
if step % 1000 == 0:
# save session
save_path = saver.save(sess, "ckpt/model", global_step=step)
#print("Model saved in file: ", save_path)
tf.train.Saver()
max_to_keep参数:
自动保存的最近n个ckpt文件。 默认n=5。若n=0或者None,则保存所有的ckpt文件。
(循环保存,之前保存的模型会被删除掉。会替换已存在的同名文件)keep_checkpoint_every_n_hour参数:
每N小时保存一次checkpoint 文件。
模型恢复
流程demo:
# 恢复 图结构
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
# 开启会话
with tf.Session() as sess:
# 恢复 图结构中 最新的已存参数
saver.restore(sess, tf.train.latest_checkpoint('./'))
# 获取默认图结构,方便之后取出需要的节点
graph = tf.get_default_graph()
# 读取需要的Tensor node
input_x = graph.get_tensor_by_name("input_x:0")
# ...
# 之后正常的sess.run就好了
y_pred = sess.run(y_output, feed_dict={input_x: test_image})
一点经验
- 只需要用
get_tensor_by_name
函数来获取模型的“头”和“尾”节点,即可开始跑模型。 - 模型恢复时,不需要再对变量做初始化。
代码部分
延用了上一篇的代码。其中因为 执行会话阶段 之前的代码和上一篇的完全一毛一样,就不复制过来了。
改动部分:
1. 经过若干轮训练, 输出当前准确度时,顺便测试了验证集的准确率,方便对比过拟合的程度。
2. 加入了模型保存模块。
###### 执行会话阶段
# 定义保存器
saver = tf.train.Saver(max_to_keep=3)
# 模型保存路径
save_path = 'ckpt/'
# 启动会话
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess :
sess.run(tf.global_variables_initializer())
start_time = time.time()
# iteration
for i in range(10000):
batch = mnist.train.next_batch(32)
# run. 训练阶段dropout设置为0.5,预测阶段设置为1.0
sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
# 经过若干轮, 输出当前准确度
if i%500 == 0:
end_time = time.time()
# 计算当前训练集的准确率
train_accuracy = sess.run(accuracy, feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
# 计算当前验证集的准确率
vat_accuracy = sess.run(accuracy, feed_dict={x:mnist.validation.images, y_: mnist.validation.labels, keep_prob: 1.0})
# 输出当前准确率
print("step %d, training accuracy :%.4g, validation accuracy :%.4g, used time: %.2fs"%(i, train_accuracy, vat_accuracy, end_time-start_time))
start_time = time.time()
# 保存模型
saver.save(sess, save_path+'mnist.ckpt', global_step=i)
# 对测试集进行测试
# 分段测试
test_accuracys = []
for i in range(0, int(mnist.test.images.shape[0]), 32):
test_images = mnist.test.images[i:i + 32]
test_labels = mnist.test.labels[i:i + 32]
test_accuracy = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels, keep_prob: 1.0})
test_accuracys.append(test_accuracy)
print("test accuracy :%g" % (sum(test_accuracys) / len(test_accuracys)))
我的输出结果:
...
step 9000, training accuracy :1, validation accuracy :0.9884, used time: 3.96s
step 9500, training accuracy :0.9688, validation accuracy :0.9878, used time: 3.70s
test accuracy :0.990116
参考文章
https://blog.csdn.net/huachao1001/article/details/78501928
适合入门和备忘,基本操作都有了,举例十分丰富。翻译得很好。
https://www.cnblogs.com/hellcat/p/6925757.html#_label0_3
讲的比较全,提到了二进制模型加载方法?