恢复模型和获取特定张量:
从pb文件中恢复模型:
with tf.gfile.GFile('./frozen.pb','rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
从ckpt和meta文件中恢复,ckpt文件保存的是模型的张量数据,meta保存的是图的架构,先导入图,然后导入各个张量的值。
meta_file, ckpt_file = 'model.meta','model.ckpt'
saver = tf.train.import_meta_graph(meta_file, input_map=input_map)
saver.restore(tf.get_default_session(), ckpt_file)
找到特定张量,使用函数graph.get_tensor_by_name,得到的是个张量,可以直接作为Session.run(graph.get_tensor_by_name(' ')),如果张量需要输入,还需要有feed_dict。取graph的部分子图,可以用来feed相应的张量,然后输出想要的值。
embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
print(embeddings)
输出:Tensor("embeddings:0", shape=(?, 512), dtype=float32)
找到特定操作Operation,使用函数graph.get_operation_by_name,得到的是操作节点,字典形式,其中包括输入和输出张量
input_operation = graph.get_operation_by_name("import/Mul")
output_operation = graph.get_operation_by_name("import/final_result")
print(input_operation)
print(output_operation)
输出:
name: "import/Mul"
op: "Mul"
input: "import/Sub"
input: "import/Mul/y"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
name: "import/final_result"
op: "Softmax"
input: "import/final_training_ops/Wx_plus_b/add"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
可以通过字典找到相应的输出
print(input_operation.name)
print(output_operation.name)
print(output_operation.inputs[0])
print(input_operation.outputs)
输出:
import/Mul
import/final_result
Tensor("import/final_training_ops/Wx_plus_b/add:0", shape=(?, 9), dtype=float32)
[<tf.Tensor 'import/Mul:0' shape=(1, 299, 299, 3) dtype=float32>]扫描二维码关注公众号,回复: 2873754 查看本文章