TensorFlow保存以及恢复模型找到特定张量以及操作

恢复模型和获取特定张量:
从pb文件中恢复模型:

with tf.gfile.GFile('./frozen.pb','rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

从ckpt和meta文件中恢复,ckpt文件保存的是模型的张量数据,meta保存的是图的架构,先导入图,然后导入各个张量的值。

meta_file, ckpt_file = 'model.meta','model.ckpt'
saver = tf.train.import_meta_graph(meta_file, input_map=input_map)
saver.restore(tf.get_default_session(), ckpt_file)

找到特定张量,使用函数graph.get_tensor_by_name,得到的是个张量,可以直接作为Session.run(graph.get_tensor_by_name(' ')),如果张量需要输入,还需要有feed_dict。取graph的部分子图,可以用来feed相应的张量,然后输出想要的值。

embeddings = tf.get_default_graph().get_tensor_by_name("embeddings:0")
print(embeddings)

输出:Tensor("embeddings:0", shape=(?, 512), dtype=float32)

找到特定操作Operation,使用函数graph.get_operation_by_name,得到的是操作节点,字典形式,其中包括输入和输出张量

input_operation = graph.get_operation_by_name("import/Mul")
output_operation = graph.get_operation_by_name("import/final_result")
print(input_operation)
print(output_operation)

输出: 

name: "import/Mul"
op: "Mul"
input: "import/Sub"
input: "import/Mul/y"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}


name: "import/final_result"
op: "Softmax"
input: "import/final_training_ops/Wx_plus_b/add"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

 可以通过字典找到相应的输出

print(input_operation.name)
print(output_operation.name)
print(output_operation.inputs[0])
print(input_operation.outputs)

输出: 

import/Mul
import/final_result
Tensor("import/final_training_ops/Wx_plus_b/add:0", shape=(?, 9), dtype=float32)
[<tf.Tensor 'import/Mul:0' shape=(1, 299, 299, 3) dtype=float32>]

扫描二维码关注公众号,回复: 2873754 查看本文章

猜你喜欢

转载自blog.csdn.net/shiheyingzhe/article/details/81840007