pb格式,可以把训练好的模型的参数固话,便于调用。
举个示例:
with tf.Session(graph=tf.Graph()) as sess:
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
# 这里的输出需要加上name属性
op = tf.add(xy, b, name='op_to_store')
sess.run(tf.global_variables_initializer())
# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
# 测试 OP
feed_dict = {x: 10, y: 3}
print(sess.run(op, feed_dict))
# 写入序列化的 PB 文件
with tf.gfile.FastGFile('model.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
# 输出
# INFO:tensorflow:Froze 1 variables.
# Converted 1 variables to const ops.
# 31
输出文件如下:
这个pb文件如何调用呢?
我们再写个demo
sess = tf.Session()
with tf.gfile.FastGFile('model.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='') # 导入计算图
# 需要有一个初始化的过程
sess.run(tf.global_variables_initializer())
# 需要先复原变量
print(sess.run('b:0'))
# 1
# 输入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op_to_store:0')
ret = sess.run(op, feed_dict={input_x: 5, input_y: 5})
print(ret)
# 输出 26
这里根据名称获取张量为什么是x:0而不是x呢?
继续研究一下:打个断点发现name确实是x:0,至于为什么,现在还不是太清楚,后续再慢慢体会。
***********************由于我们服务是用C#调用的,所以决定尝试一下C#调用pb文件**********
扫描二维码关注公众号,回复:
5092959 查看本文章
参考上面牛人,表示感谢
首先安装
然后就可以直接调用了
Console.WriteLine("Hello World!");
var graph = new TFGraph();
var model = File.ReadAllBytes("model.pb");
graph.Import(model);
using (var session = new TFSession(graph))
{
var runner = session.GetRunner();
TFTensor xTf = new TFTensor(new int[2] { 10,15});
TFTensor yTf = new TFTensor(new int[2] { 2, 2 });
// 其中的graph["input"][0], graph["output"][0]指的是,input节点的第1个输出,和 output节点的第1个输出,等同于python中的input:0 output:0
// 其中Fetch()用于取得输出变量。
runner.AddInput(graph["x"][0], xTf);
runner.AddInput(graph["y"][0], yTf);
runner.Fetch(graph["op_to_store"][0]);
var output = runner.Run();
var result = output[0];
var iData = result.GetValue(true);
Console.ReadKey();
}
Console.ReadKey();
看下运行结果:
是不是完全符合预期结果。