Tensorflow如何保存、读取model (即利用训练好的模型测试新数据的准确度)

目标:

cnn2d.py cnn2d_test.py
训练网络,并保存网络模型 读取网络,用测试集测试准确度

直接贴代码:(只贴了相关部分,浏览完整代码请到GitHub

1. cnn2d.py

import tensorflow as tf
import numpy as np
from sklearn import metrics

print("### Process1 --- data load ###")
# 读取数据
print("### Process2 --- data spilt ###")
# 形成训练集和验证集

# 定义
# ···
X = tf.placeholder(tf.float32, (None, seg_height, seg_len, num_channels), name='X')
Y = tf.placeholder(tf.float32, (None, num_labels), name='Y')
# 注意name='X'和name='Y'

# 网络结构
# ···

# loss training等定义
# ···

accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
# 注意name="accuracy"

with tf.Session() as session:
    # ···
    saver = tf.train.Saver()
    tf.train.Saver().save(session, "./model/HAR-UCI_model_1")

# Epoch:  100  Training Loss:  0.11656471  Training Accuracy:  0.9664903
# Epoch:  100 Valid Accuracy: 0.96754116
# Epoch:  100 Test Accuracy: 0.9321344
# ### Save model_1 successfully ###

虽然只给X, Y, accuracy命名,但网络其余结构、参数均自动分配了名字并保存在./model/中。

./model/中保存的文件:

只给X, Y, accuracy命名是因为在下面一个程序中只用到了这三个参数。

2. cnn2d_test.py

import tensorflow as tf
import numpy as np

# 读取测试集test_x和test_y
# ···

saver = tf.train.import_meta_graph("./model/HAR-UCI_model_1.meta")
with tf.Session() as session:
    saver.restore(session, tf.train.latest_checkpoint("./model/"))
    graph = tf.get_default_graph()
    feed_dict = {"X:0": test_x, "Y:0": test_y}
    acc = graph.get_tensor_by_name("accuracy:0")
    test_acc = session.run(acc, feed_dict=feed_dict)
    print("Test Accuracy:", test_acc)

# Test Accuracy: 0.9321344

猜你喜欢

转载自blog.csdn.net/kane7csdn/article/details/83794941