tensorflow之train.get_checkpoint_state

tf.train.get_checkpoint_state函数通过checkpoint文件找到模型文件名。

该函数返回的是checkpoint文件CheckpointState proto类型的内容,其中有model_checkpoint_path和all_model_checkpoint_paths两个属性。其中model_checkpoint_path保存了最新的tensorflow模型文件的文件名,all_model_checkpoint_paths则有未被删除的所有tensorflow模型文件的文件名。
 

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
                else:
                    print('No checkpoint file found')
                    return

首先通过get_checkpoint_state检查路局下是否存在训练好的模型

再通过restore恢复模型

通过文件名还能获取到global_step
再直接调用就可以把训练好的模型拿来使用

猜你喜欢

转载自blog.csdn.net/g0415shenw/article/details/85779002