Tensorflow复习笔记3:Saver模块

模型保存

Saver 保存的是 图结构 和 session中的trainable的张量。
流程demo:

    saver = tf.train.Saver(max_to_keep=3)
    #...
    with tf.Session() as sess:
        # ...
        for step in range(1000000):
            # ...
            if step % 1000 == 0:
                # save session
                save_path = saver.save(sess, "ckpt/model", global_step=step)
                #print("Model saved in file: ", save_path)
  • tf.train.Saver()

    • max_to_keep参数:
      自动保存的最近n个ckpt文件。 默认n=5。若n=0或者None,则保存所有的ckpt文件。
      (循环保存,之前保存的模型会被删除掉。会替换已存在的同名文件)

    • keep_checkpoint_every_n_hour参数:
      每N小时保存一次checkpoint 文件。


模型恢复

流程demo:

# 恢复 图结构
saver = tf.train.import_meta_graph('my_test_model-1000.meta')

# 开启会话
with tf.Session() as sess:
    # 恢复 图结构中 最新的已存参数
    saver.restore(sess, tf.train.latest_checkpoint('./'))

    # 获取默认图结构,方便之后取出需要的节点
    graph = tf.get_default_graph()
    # 读取需要的Tensor node 
    input_x = graph.get_tensor_by_name("input_x:0")
    # ...

    # 之后正常的sess.run就好了
    y_pred = sess.run(y_output, feed_dict={input_x: test_image})
一点经验
  • 只需要用get_tensor_by_name函数来获取模型的“头”和“尾”节点,即可开始跑模型。
  • 模型恢复时,不需要再对变量做初始化。

代码部分

延用了上一篇的代码。其中因为 执行会话阶段 之前的代码和上一篇的完全一毛一样,就不复制过来了。

改动部分:
1. 经过若干轮训练, 输出当前准确度时,顺便测试了验证集的准确率,方便对比过拟合的程度。
2. 加入了模型保存模块。

###### 执行会话阶段
# 定义保存器
saver = tf.train.Saver(max_to_keep=3)
# 模型保存路径
save_path = 'ckpt/'

# 启动会话
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
with tf.Session(config=config) as sess :
    sess.run(tf.global_variables_initializer())

    start_time = time.time()
    # iteration
    for i in range(10000):
      batch = mnist.train.next_batch(32)
      # run. 训练阶段dropout设置为0.5,预测阶段设置为1.0
      sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

      # 经过若干轮, 输出当前准确度
      if i%500 == 0:
        end_time = time.time()

        # 计算当前训练集的准确率
        train_accuracy = sess.run(accuracy, feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})
        # 计算当前验证集的准确率
        vat_accuracy = sess.run(accuracy, feed_dict={x:mnist.validation.images, y_: mnist.validation.labels, keep_prob: 1.0})

        # 输出当前准确率
        print("step %d, training accuracy :%.4g, validation accuracy :%.4g, used time: %.2fs"%(i, train_accuracy, vat_accuracy, end_time-start_time))

        start_time = time.time()

      # 保存模型
        saver.save(sess, save_path+'mnist.ckpt', global_step=i)


    # 对测试集进行测试
    # 分段测试
    test_accuracys = []
    for i in range(0, int(mnist.test.images.shape[0]), 32):
        test_images = mnist.test.images[i:i + 32]
        test_labels = mnist.test.labels[i:i + 32]
        test_accuracy = sess.run(accuracy, feed_dict={x: test_images, y_: test_labels, keep_prob: 1.0})
        test_accuracys.append(test_accuracy)
    print("test accuracy :%g" % (sum(test_accuracys) / len(test_accuracys)))

我的输出结果:

...
step 9000, training accuracy :1, validation accuracy :0.9884, used time: 3.96s
step 9500, training accuracy :0.9688, validation accuracy :0.9878, used time: 3.70s
test accuracy :0.990116

参考文章

https://blog.csdn.net/huachao1001/article/details/78501928
适合入门和备忘,基本操作都有了,举例十分丰富。翻译得很好。
https://www.cnblogs.com/hellcat/p/6925757.html#_label0_3
讲的比较全,提到了二进制模型加载方法?

猜你喜欢

转载自blog.csdn.net/yiranzhiliposui/article/details/81044185
今日推荐