tensorflow之pd模型

pb格式,可以把训练好的模型的参数固话,便于调用。

举个示例:

with tf.Session(graph=tf.Graph()) as sess:
    x = tf.placeholder(tf.int32, name='x')
    y = tf.placeholder(tf.int32, name='y')
    b = tf.Variable(1, name='b')
    xy = tf.multiply(x, y)
    # 这里的输出需要加上name属性
    op = tf.add(xy, b, name='op_to_store')
    sess.run(tf.global_variables_initializer())
    # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
    constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
    # 测试 OP
    feed_dict = {x: 10, y: 3}
    print(sess.run(op, feed_dict))
    # 写入序列化的 PB 文件
    with tf.gfile.FastGFile('model.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
    # 输出
    # INFO:tensorflow:Froze 1 variables.
    # Converted 1 variables to const ops.
    # 31

输出文件如下:

 这个pb文件如何调用呢?

我们再写个demo

sess = tf.Session()
with tf.gfile.FastGFile('model.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='') # 导入计算图
# 需要有一个初始化的过程    
sess.run(tf.global_variables_initializer())
# 需要先复原变量
print(sess.run('b:0'))
# 1
# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op_to_store:0')
ret = sess.run(op,  feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26

 这里根据名称获取张量为什么是x:0而不是x呢?

继续研究一下:打个断点发现name确实是x:0,至于为什么,现在还不是太清楚,后续再慢慢体会。

***********************由于我们服务是用C#调用的,所以决定尝试一下C#调用pb文件**********

http://imlihang.cn/?p=183

扫描二维码关注公众号,回复: 5092959 查看本文章

参考上面牛人,表示感谢

首先安装

然后就可以直接调用了

Console.WriteLine("Hello World!");
            var graph = new TFGraph();
            var model = File.ReadAllBytes("model.pb");
            graph.Import(model);

            using (var session = new TFSession(graph))
            {
                var runner = session.GetRunner();
                TFTensor xTf = new TFTensor(new int[2] { 10,15});
                TFTensor yTf = new TFTensor(new int[2] { 2, 2 });
                // 其中的graph["input"][0], graph["output"][0]指的是,input节点的第1个输出,和 output节点的第1个输出,等同于python中的input:0 output:0 
                // 其中Fetch()用于取得输出变量。 
                runner.AddInput(graph["x"][0], xTf);
                runner.AddInput(graph["y"][0], yTf);
                runner.Fetch(graph["op_to_store"][0]);
                var output = runner.Run();
                var result = output[0];
                var iData = result.GetValue(true);
                Console.ReadKey();
            }

            Console.ReadKey();

 看下运行结果:

 是不是完全符合预期结果。

猜你喜欢

转载自blog.csdn.net/g0415shenw/article/details/86597817
今日推荐