Packaging and using a TensorFlow model

After training a model with TensorFlow, how do you put it to use in your own product?

Reference source code (GitHub): https://github.com/azuredsky/tensorflow-tutorial-1

The key steps are recorded below.

1. Freeze the trained model into a binary protobuf

Use the freeze_graph.py tool:

python freeze_graph.py --input_graph=../model/nn_model.pbtxt --input_checkpoint=../ckpt/nn_model.ckpt --output_graph=../model/nn_model_frozen.pb --output_node_names=output_node

Argument 1: ../model/nn_model.pbtxt

This file is generated by tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True), which writes the graph's computation nodes to nn_model.pbtxt (a training-side sketch follows this argument list).

Argument 2: ../ckpt/nn_model.ckpt, the checkpoint containing the trained parameters.

Argument 3: ../model/nn_model_frozen.pb, the name of the frozen output file.

Argument 4: output_node, the name of the output node in the graph. You can print every node name with the following Python code:

        for op in tf.get_default_graph().get_operations():

            print(op.name)

The output node must be one that is actually evaluated with sess.run in your Python code; otherwise it will not be kept in the frozen graph.
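For reference, here is a minimal training-side sketch (assumed, TF 1.x API, not taken from the original repo) that produces nn_model.pbtxt and nn_model.ckpt. The tiny dense model is only a stand-in for your own network, but the node names (inputs, output_node) match the ones used throughout this article:

import tensorflow as tf

# Hypothetical minimal graph: replace with your own model
inputs = tf.placeholder(tf.float32, [None, 32, 24, 3], name="inputs")
logits = tf.layers.dense(tf.layers.flatten(inputs), 10)
# Give the export node a fixed name so freeze_graph.py can find it
output = tf.identity(tf.argmax(logits, axis=1), name="output_node")

saver = tf.train.Saver()
with tf.Session() as session:
    session.run(tf.global_variables_initializer())
    # ... training would go here; evaluate `output` with session.run at least once ...
    # Graph structure -> --input_graph for freeze_graph.py
    tf.train.write_graph(session.graph_def, "../model", "nn_model.pbtxt", as_text=True)
    # Trained weights -> --input_checkpoint for freeze_graph.py
    saver.save(session, "../ckpt/nn_model.ckpt")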

Once nn_model_frozen.pb has been produced, it can be loaded directly from Python or C++.

Python example:

import tensorflow as tf

def load_graph(fz_gh_fn):
    # Read the frozen protobuf and parse it into a GraphDef
    with tf.gfile.GFile(fz_gh_fn, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the GraphDef into a fresh graph; every node name gets the "prefix/" prefix
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="prefix"
        )
    return graph

With load_graph defined, the model can be loaded and run:

import numpy as np
from PIL import Image

graph = load_graph(args.fz_model_fn)

# Tensor names are the original node names plus the "prefix" passed to import_graph_def
x = graph.get_tensor_by_name('prefix/inputs:0')
y = graph.get_tensor_by_name('prefix/output_node:0')

# img_H, img_W and channels must match the input shape the model was trained with
img = Image.open('./test_image/8_0031.bmp')
flatten_img = np.reshape(img, [1, img_H, img_W, channels])

with tf.Session(graph=graph) as sess:
    y_out = sess.run(y, feed_dict={x: flatten_img})
    print(y_out)
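If you are unsure of the exact tensor names in the frozen graph, you can list the imported operations, the same trick shown earlier, applied to the loaded graph:

for op in graph.get_operations():
    print(op.name)  # e.g. prefix/inputs, prefix/output_node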

C++ example:

First, build the TensorFlow C++ shared library yourself (for example with Bazel).

Then load the model and run it:

#include <cstring>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include <opencv2/opencv.hpp>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/public/session.h"

using namespace cv;
using namespace tensorflow;

int main(int argc, char* argv[]) {
    // Initialize a tensorflow session
    Session* session;
    Status status = NewSession(SessionOptions(), &session);
    if (!status.ok()) {
        std::cout << status.ToString() << "\n";
        return 1;
    }

    // Read in the protobuf graph we exported
    // (The path seems to be relative to the cwd. Keep this in mind
    // when using `bazel run` since the cwd isn't where you call
    // `bazel run` but from inside a temp folder.)
    GraphDef graph_def;
    status = ReadBinaryProto(Env::Default(), "nn_model_frozen.pb", &graph_def);
    if (!status.ok()) {
        std::cout << status.ToString() << "\n";
        return 1;
    }

    // Add the graph to the session
    status = session->Create(graph_def);
    if (!status.ok()) {
        std::cout << status.ToString() << "\n";
        return 1;
    }

    // Tensor names as defined in the Python graph; no "prefix/" here because
    // the GraphDef is loaded directly rather than via import_graph_def
    std::string input_tensor_name = "inputs:0";
    std::string output_tensor_name = "output_node:0";

    Tensor x(DT_FLOAT, TensorShape({1, 32, 24, 3})); // input tensor, NHWC: [batch, height, width, channels]
    Mat img;
    img = imread("4_0179.bmp",1);
    imshow("img", img);
    waitKey(10);
    auto x_map = x.tensor<float, 4>(); // Eigen tensor view over the input Tensor's buffer
    cvtColor(img, img, COLOR_RGB2BGR); // swap channel order to match what the model expects
    img.convertTo(img, CV_32FC3);      // convert to 32-bit float, 3 channels
    memcpy(x_map.data(), (float*)img.data, 1 * 24 * 32 * 3 * sizeof(float));
    // Setup inputs and outputs:
    std::vector<std::pair<string, Tensor>> inputs = {
        { input_tensor_name, x }};

    // The session will initialize the outputs
    std::vector<Tensor> outputs;

    // Run the session, evaluating our "y" operation from the graph
    status = session->Run(inputs, { output_tensor_name }, {}, &outputs);
    if (!status.ok()) {
        std::cout << status.ToString() << " (line " << __LINE__ << ")\n";
        return 1;
    }

    // Grab the first output (we only evaluated one graph node: "output_node")
    // and read it back as a scalar.
    auto output_y = outputs[0].scalar<int>();

    // (There are similar methods for vectors and matrices here:
    // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/public/tensor.h)

    // Print the results
    std::cout << outputs[0].DebugString() << "\n"; // Tensor<type: float shape: [] values: 32>
    std::cout << output_y() << "\n"; // 32

    // Free any resources used by the session
    session->Close();

    return 0;
}
Note that the input image's channel order must match what the model was trained with (BGR in this example), so convert your own images accordingly before feeding them.
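If you feed images from Python as in the earlier example, the same channel swap can be done with NumPy. A minimal sketch, assuming the PIL image is RGB:

import numpy as np
from PIL import Image

img = np.asarray(Image.open('./test_image/8_0031.bmp').convert('RGB'), dtype=np.float32)
img_H, img_W, channels = img.shape
img_bgr = img[:, :, ::-1]                                  # swap RGB -> BGR channel order
flatten_img = img_bgr.reshape(1, img_H, img_W, channels)   # batch of one, NHWC layout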

Reprinted from blog.csdn.net/chfeilong0202/article/details/80619155