TensorFlow Graph: Save, Parse & Load; Freeze; Optimize

I. Graph Save, Parse & Load

1. Function reference

(a) Graph Save

tf.train.write_graph(graph_or_graph_def, logdir, name, as_text=True)

Purpose:

  • Writes a graph proto to a file.

Parameters:

  • graph_or_graph_def: a Graph or a GraphDef protocol buffer
  • logdir: directory in which to write the graph
  • name: file name for the graph
  • as_text: if True, writes the graph as an ASCII (text-format) proto; if False, writes it in binary format

Returns:

  • The path of the output proto file.

graph_def.SerializeToString()

Purpose:

  • Serializes the protocol message to a binary string.

Returns:

  • A binary string representation of the message (a bytes object in Python 3).
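As a small sketch of the two save paths described above, the snippet below writes the same graph once as text with write_graph and once as binary by writing the bytes from SerializeToString() by hand. It assumes TensorFlow is installed and reaches the 1.x-style API through tf.compat.v1 so that it can also run on TensorFlow 2.x; the helper name save_both_formats, the node names "x" and "y", and the temporary directory are made up for illustration.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style API reached through compat.v1


def save_both_formats(logdir):
    """Saves one tiny graph as text (.pbtxt) and as binary (.pb)."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None], name="x")
        tf.multiply(x, 2.0, name="y")
    graph_def = graph.as_graph_def()

    # Text format: write_graph returns the path of the file it wrote.
    text_path = tf1.train.write_graph(graph_def, logdir, "graph.pbtxt", as_text=True)

    # Binary format: SerializeToString() returns bytes, so open the file in "wb" mode.
    binary_path = os.path.join(logdir, "graph.pb")
    with open(binary_path, "wb") as f:
        f.write(graph_def.SerializeToString())
    return graph_def, text_path, binary_path


graph_def, text_path, binary_path = save_both_formats(tempfile.mkdtemp())
print(text_path, binary_path)
```

Calling write_graph with as_text=False would produce the binary file as well; the manual version is shown only to make the role of SerializeToString() visible.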

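(b) Graph Parse

input_graph_def.ParseFromString(serialized)

Purpose:

  • Parses serialized protocol buffer data into this message (use this for binary pb files).

Parameters:

  • serialized: the serialized protocol buffer data, e.g. f.read() on a file opened in "rb" mode

Returns:

  • Nothing you need to keep: input_graph_def itself is filled in place with the parsed content, so do not rely on the return value being the message.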

text_format.Merge(text=f.read(), message=input_graph_def)

Purpose:

  • Parses a text representation of a protocol message into a message (use this for text-format pb files).

Parameters:

  • text: the message in its text representation
  • message: the protocol buffer message to merge into

Returns:

  • The same message that was passed as the argument.
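To make the difference between the two parsers concrete, here is a minimal sketch that round-trips a hand-built GraphDef through both the binary and the text representation in memory. It assumes the TensorFlow and protobuf Python packages are installed; the single node named "x" is purely illustrative, and text_format.MessageToString is used only to produce the text input.

```python
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2

# Build a tiny GraphDef by hand; the node name and op are illustrative only.
original = graph_pb2.GraphDef()
node = original.node.add()
node.name = "x"
node.op = "Placeholder"

binary_data = original.SerializeToString()         # bytes
text_data = text_format.MessageToString(original)  # str

# Binary data: ParseFromString fills the message in place.
from_binary = graph_pb2.GraphDef()
from_binary.ParseFromString(binary_data)

# Text data: Merge fills the message in place and also returns it.
from_text = graph_pb2.GraphDef()
returned = text_format.Merge(text_data, from_text)

binary_matches = from_binary == original
text_matches = from_text == original

# Merge does not clear the target first, so merging again appends to repeated
# fields, whereas ParseFromString always starts from a cleared message.
text_format.Merge(text_data, from_text)
from_binary.ParseFromString(binary_data)
print(binary_matches, text_matches, len(from_text.node), len(from_binary.node))
```

The last two calls are why it is safest to parse into a freshly created GraphDef, as the examples in this article do.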

(c) Graph Load

tf.import_graph_def(graph_def, input_map=None, return_elements=None, name=None)

Purpose:

  • Imports the graph from graph_def into the current default Graph.

Parameters:

  • graph_def: a GraphDef proto containing the operations to import into the default graph
  • input_map: a dictionary mapping input names (as strings) in graph_def to Tensor objects. The values of the named input tensors in the imported graph are re-mapped to the respective Tensor values.
  • return_elements: a list of strings containing operation names in graph_def that will be returned as Operation objects, and/or tensor names in graph_def that will be returned as Tensor objects
  • name: a prefix that is prepended to the names in graph_def. Note that this does not apply to imported function names. Defaults to "import"; pass name="" to keep the original names.

Returns:

  • A list of Operation and/or Tensor objects from the imported graph, corresponding to the names in return_elements.
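The input_map, return_elements and name parameters are easier to follow on a toy graph. The sketch below is one possible illustration, assuming TensorFlow is installed and using tf.compat.v1 for the session and placeholder APIs; make_graph_def, the node names "x", "y", "new_x" and the "imported" prefix are invented for this example.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style API reached through compat.v1


def make_graph_def():
    """Builds y = x * 2 and returns its GraphDef (names are illustrative)."""
    with tf.Graph().as_default() as graph:
        x = tf1.placeholder(tf.float32, shape=[None], name="x")
        tf.multiply(x, 2.0, name="y")
    return graph.as_graph_def()


graph_def = make_graph_def()

with tf.Graph().as_default() as graph:
    new_x = tf1.placeholder(tf.float32, shape=[None], name="new_x")
    # Re-map the imported tensor "x:0" to new_x and ask for the tensor "y:0" back.
    (y,) = tf.import_graph_def(
        graph_def,
        input_map={"x:0": new_x},
        return_elements=["y:0"],
        name="imported",
    )
    with tf1.Session(graph=graph) as sess:
        result = sess.run(y, feed_dict={new_x: [1.0, 2.0, 3.0]})

print(y.name, result)
```

Because name="imported" was passed, the returned tensor lives under that prefix in the new graph, while the keys of input_map and the entries of return_elements still use the names from graph_def.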

2. Code examples

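# Imports needed by the snippets below (TensorFlow 1.x)
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer
from tensorflow.python.platform import gfile

# (a) graph save
v = tf.Variable(0, name='my_variable')
sess = tf.Session()
tf.train.write_graph(sess.graph, '/tmp/my-model', 'train.pbtxt', as_text=True)
# or: tf.train.write_graph(sess.graph_def, '/tmp/my-model', 'train.pbtxt')


# (b) graph parse
def _parse_input_graph_proto(input_graph, input_binary):
    """Reads a GraphDef from a binary (input_binary=True) or text-format file."""
    if not gfile.Exists(input_graph):
        print("Input graph file '" + input_graph + "' does not exist!")
        # Return None (not -1) so that the `if input_graph_def:` check below is False
        return None
    # Create an empty GraphDef object
    input_graph_def = graph_pb2.GraphDef()
    mode = "rb" if input_binary else "r"
    with gfile.GFile(input_graph, mode) as f:
        if input_binary:
            input_graph_def.ParseFromString(f.read())
        else:
            text_format.Merge(f.read(), input_graph_def)
            # If the line above fails, try decoding the data as UTF-8 first:
            # text_format.Merge(f.read().decode("utf-8"), input_graph_def)
    return input_graph_def

# (c) graph load
# input_graph_def is the value returned by _parse_input_graph_proto
if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")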

II. Graph Freeze

1. Function reference

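tf.graph_util.convert_variables_to_constants(sess, input_graph_def, output_node_names, variable_names_whitelist=None, variable_names_blacklist=None)

Purpose:

  • Replaces the variables in a graph with constants that hold the same values.

Parameters:

  • sess: active TensorFlow session containing the variables
  • input_graph_def: GraphDef object holding the network
  • output_node_names: a list of the names of the nodes from which you want to extract the results of your graph
  • variable_names_whitelist: the set of variable names to convert (by default, all variables are converted)
  • variable_names_blacklist: the set of variable names to leave unconverted

Returns:

  • GraphDef containing a simplified version of the original.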

2. Code examples

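# Imports needed by this snippet (TensorFlow 1.x)
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib

# Assumes input_graph_def has already been imported into the default graph
# (as in the graph load example above) and that input_checkpoint,
# output_node_names, output_graph and the whitelist/blacklist are defined.

# Restore only the checkpoint variables that exist in the graph; skip the rest
with session.Session() as sess:
    var_list = {}
    reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in var_to_shape_map:
        try:
            tensor = sess.graph.get_tensor_by_name(key + ":0")
        except KeyError:
            # This tensor doesn't exist in the graph (for example it's
            # 'global_step' or a similar housekeeping element) so skip it.
            continue
        var_list[key] = tensor
    saver = saver_lib.Saver(var_list=var_list)
    saver.restore(sess, input_checkpoint)

    # Freeze: replace the restored variables with constants
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        input_graph_def,
        output_node_names.split(","),
        variable_names_whitelist=variable_names_whitelist,
        variable_names_blacklist=variable_names_blacklist)

# Write GraphDef to file if output path has been given.
if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())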
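For readers who want to try freezing without a checkpoint at hand, the following is a self-contained sketch rather than the procedure used above: it initializes a single variable in the session instead of restoring it with a Saver. It assumes TensorFlow is installed and uses tf.compat.v1 so it can also run on TensorFlow 2.x; the names "x", "w" and "y" and the value 3.0 are invented for the example.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: TF 1.x-style API reached through compat.v1

# Build and initialize a graph with one variable: y = w * x.
with tf.Graph().as_default() as graph:
    x = tf1.placeholder(tf.float32, shape=[None], name="x")
    w = tf1.get_variable("w", initializer=3.0)
    tf.multiply(w, x, name="y")
    with tf1.Session(graph=graph) as sess:
        sess.run(tf1.global_variables_initializer())
        # Freeze: the current value of w is baked into the returned GraphDef.
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ["y"])

# The frozen graph runs in a fresh session without any variable initialization.
with tf.Graph().as_default() as frozen_graph:
    (y,) = tf.import_graph_def(frozen_graph_def, return_elements=["y:0"], name="")
    with tf1.Session(graph=frozen_graph) as sess:
        result = sess.run(y, feed_dict={"x:0": [1.0, 2.0]})

frozen_ops = sorted({node.op for node in frozen_graph_def.node})
print(frozen_ops, result)
```

Printing the op types of the frozen GraphDef is a quick way to check that no variable ops are left before handing the file to an optimization or deployment tool.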

III. Graph Optimize

1. Install the Bazel build tool

  • Install the required packages
    • sudo apt-get install pkg-config zip g++ zlib1g-dev unzip python
  • Download the Bazel installer script
  • Make the installer executable and run it
    • chmod +x bazel-<version>-installer-linux-x86_64.sh
    • ./bazel-<version>-installer-linux-x86_64.sh --user
    • The --user flag installs Bazel into the $HOME/bin directory and sets the .bazelrc path to $HOME/.bazelrc
  • Append the path of the executable to the end of ~/.bashrc
    • export PATH="$PATH:$HOME/bin"

2. Optimize with transform_graph

"""
Removes all of the nodes that aren't called during inference, shrinks expressions that are always
constant into single nodes, and optimizes away some multiply operations used during batch normalization
by pre-multiplying the weights for convolutions.
"""


# Optimize the graph. Run this from the root of the TensorFlow source tree.
# The tool has to be built first; building it once is enough.
# Note: change shape in strip_unused_nodes to the size of your own input.
bazel build tensorflow/tools/graph_transforms:transform_graph && \
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=frozen_graph.pb \
--out_graph=optimized_graph.pb \
--inputs='ExpandDims_1' \
--outputs='decode/SparseToDense' \
--transforms='
  strip_unused_nodes(type=float, shape="1,48,160,3")
  remove_nodes(op=Identity, op=CheckNumerics)
  fold_constants(ignore_errors=true)
  fold_batch_norms
  fold_old_batch_norms'

2018-07-06 14:33:27.616294: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying strip_unused_nodes
2018-07-06 14:33:27.623069: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying remove_nodes
2018-07-06 14:33:27.654728: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying fold_constants
2018-07-06 14:33:27.678252: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying fold_batch_norms
2018-07-06 14:33:27.683107: I tensorflow/tools/graph_transforms/transform_graph.cc:318] Applying fold_old_batch_norms
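If the optimization step needs to be scripted, the same command line can be assembled in Python. The sketch below only builds and prints the argument list and does not run the tool; build_transform_graph_command is a hypothetical helper written for this article, and the paths, node names and input shape are simply the ones from the command above.

```python
import shlex


def build_transform_graph_command(in_graph, out_graph, inputs, outputs, input_shape):
    """Returns the transform_graph invocation as an argument list (hypothetical helper)."""
    transforms = [
        'strip_unused_nodes(type=float, shape="{}")'.format(input_shape),
        "remove_nodes(op=Identity, op=CheckNumerics)",
        "fold_constants(ignore_errors=true)",
        "fold_batch_norms",
        "fold_old_batch_norms",
    ]
    return [
        "bazel-bin/tensorflow/tools/graph_transforms/transform_graph",
        "--in_graph=" + in_graph,
        "--out_graph=" + out_graph,
        "--inputs=" + inputs,
        "--outputs=" + outputs,
        # The transforms are passed as one whitespace-separated argument.
        "--transforms=" + " ".join(transforms),
    ]


command = build_transform_graph_command(
    "frozen_graph.pb", "optimized_graph.pb",
    "ExpandDims_1", "decode/SparseToDense", "1,48,160,3")

# Print a shell-quoted version for inspection.
print(" ".join(shlex.quote(arg) for arg in command))
# To actually run it (from the TensorFlow source root, after the bazel build):
#   import subprocess; subprocess.run(command, check=True)
```

Passing the arguments as a list avoids shell quoting problems with the parentheses and double quotes inside the --transforms value.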

IV. References

1. Installing Bazel on Ubuntu
2. TensorFlow Graph Transform Tool
3. TensorFlow Python tools: freeze_graph.py

Reposted from blog.csdn.net/mzpmzk/article/details/80984967