如何使用训练好的tensorflow

from tensorflow.python.framework import graph_util
 
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])
            with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
f.write(constant_graph.SerializeToString())

这两句是重要的代码,用来把训练好的模型保存为pb文件。运行完之后就会发现应该的文件夹多出了一个pb文件。

  1. test
def recognize(jpg_path, pb_file_path):
    with tf.Graph().as_default():
        output_graph_def = tf.GraphDef()

        with open(pb_file_path, "rb") as f:
            output_graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(output_graph_def, name="")

打开相应的pb文件。

            img = io.imread(jpg_path)
            img = transform.resize(img, (224, 224, 3))
            img_out_softmax = sess.run(out_softmax, feed_dict={input_x:np.reshape(img, [-1, 224, 224, 3])})

读取图片文件,resize之后放入模型的输入位置,之后img_out_softmax就是相应输出的结果。

猜你喜欢

转载自blog.csdn.net/a13662080711/article/details/81540363