TensorFlow in practice: a summary of the steps for converting a ckpt to a pb

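The project has been really tight lately, and I have been on a 996 schedule for many days. Having finished one small milestone, I finally have time to organize my notes, and along the way I have come to understand TensorFlow much better.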

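Back in the university lab I hardly touched TensorFlow's C++ interface. Training and inference were both done in Python, which is comparatively simple. C++ involves quite a few more steps than Python, and the first question I ran into was: how do I use a ckpt model that was trained in Python?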

Below I walk through the steps.

1. Make sure the trained model is correct

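In other words, first run inference on the trained model in Python to confirm that it works. You also need to give the model's input a corresponding placeholder. Why?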

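The variables saved in the ckpt are only the network parameters you trained, and there is no entry point left for new input. In other words, your tensor cannot get into the graph. This is typically the case when training read its data through an input pipeline rather than a placeholder. So you define a placeholder for the input. It gives your data a named place to enter the graph and be passed through the network.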

            
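 import tensorflow as tf  # TF 1.x graph-mode API
 # height, width, channel: the dimensions of your input images
 input_img = tf.placeholder(tf.float32, [None, height, width, channel],
                            name="input_img")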

This is the usual way to define the input. To support arbitrary batch sizes, the first dimension of the tensor is set to None, which should be easy to understand.

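Next, define your inference operation. Mine is fairly involved, but for some networks a single sess.run is enough. After inference, export the graph as a pb (and re-save the variables) for later use by the TensorFlow C++ interface. OpenCV can apparently load it too, but I haven't looked into that.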

Use the graph_io interface to save the graph as a pb model, and also re-save the ckpt.

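 import os
 from tensorflow.python.framework import graph_io
 graph_io.write_graph(R_sess.graph, './model_pb', "name.pb", as_text=False)  # graph structure only, as a binary GraphDef
 saver.save(R_sess, os.path.join('./model_pb', 'name.ckpt'))  # trained variable values, matching this graph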

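To sum up this part: during training we usually save the better-performing models, i.e. set up checkpoints. The save directory then contains four kinds of files:

- the checkpoint file
- the .data file of each checkpoint (for example name.ckpt.data-00000-of-00001)
- the .meta file (for example name.ckpt.meta)
- the .index file (for example name.ckpt.index)

The checkpoint file records the checkpoint names, such as which one is the latest. The .data file holds the actual variable values. The .meta file stores the graph structure, in other words your network architecture. The .index file maps each variable name to where its value is stored in the .data file(s); you rarely need to deal with it directly.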
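The sketch below shows which files make up one checkpoint, and what "prefix" means when a tool asks for a checkpoint path. It uses only the standard library. It assumes the usual text format that tf.train.Saver writes into the `checkpoint` file, i.e. a line such as `model_checkpoint_path: "model.ckpt-1000"`. The helper names `latest_checkpoint_prefix` and `files_for_prefix` are hypothetical. Inside TensorFlow, `tf.train.latest_checkpoint(ckpt_dir)` returns the latest prefix for you.

```python
import os
import re


def latest_checkpoint_prefix(ckpt_dir):
    """Read the `checkpoint` index file and return the latest checkpoint prefix.

    Assumes the text format written by tf.train.Saver, e.g.
        model_checkpoint_path: "model.ckpt-1000"
    Returns None if the file or the entry is missing.
    """
    index_path = os.path.join(ckpt_dir, "checkpoint")
    if not os.path.exists(index_path):
        return None
    with open(index_path) as f:
        match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', f.read(), re.M)
    if match is None:
        return None
    prefix = match.group(1)
    # Relative entries are interpreted relative to the checkpoint directory.
    return prefix if os.path.isabs(prefix) else os.path.join(ckpt_dir, prefix)


def files_for_prefix(prefix):
    """List the files on disk that belong to one checkpoint prefix."""
    directory, base = os.path.split(prefix)
    return sorted(name for name in os.listdir(directory or ".")
                  if name.startswith(base + "."))
```

The returned prefix (for example .../model.ckpt-1000, not one of the individual .index/.meta/.data files) is what `--input_checkpoint` expects in step 3.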

2. Confirm your input and output nodes

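Here I define a function that uses TensorBoard to view the graph structure stored in a pb. This also lets you confirm the network's input and output nodes. You need those node names before you can freeze the parameters and optimize the graph in the following steps.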

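from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary


def import_to_tensorboard(model_dir, log_dir, frozen_graph):
    """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.
    Args:
      model_dir: The location of the protobuf (`pb`) model to visualize
      log_dir: The location for the Tensorboard log to begin visualization from.
      frozen_graph: True if the file is a binary GraphDef (e.g. a frozen pb, or
        one written with as_text=False); False if it is a text (`.pbtxt`) GraphDef.
    Usage:
      Call this function with your model location and desired log directory.
      Launch Tensorboard by pointing it to the log directory.
      View your imported `.pb` model as a graph.
    """
    with session.Session(graph=ops.Graph()) as sess:
        with gfile.FastGFile(model_dir, "rb") as f:
            graph_def = graph_pb2.GraphDef()
            data = f.read()
            if frozen_graph:
                graph_def.ParseFromString(data)  # binary protobuf
            else:
                text_format.Merge(data.decode("utf-8"), graph_def)  # text protobuf
            # Nodes are imported under the "import/" name scope, so TensorBoard
            # shows e.g. "import/input_img"; drop that prefix when passing node
            # names to freeze_graph or the C++ API.
            importer.import_graph_def(graph_def)
        pb_visual_writer = summary.FileWriter(log_dir)
        pb_visual_writer.add_graph(sess.graph)
        pb_visual_writer.close()  # flush the event file to disk
        print("Model Imported. Visualize by running: "
              "> tensorboard --logdir={}".format(log_dir))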

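TensorBoard is the reliable way to confirm node names, but a quick heuristic can produce a shortlist first. Placeholders are candidate inputs, and nodes that nothing else consumes are candidate outputs. The sketch below illustrates this under stated assumptions and is not a TensorFlow API. `guess_io_nodes` is a hypothetical helper that works on plain `(name, op_type, input_names)` tuples. You could build these from a loaded GraphDef with `[(n.name, n.op, list(n.input)) for n in graph_def.node]`. A full training graph also leaves saver, summary and optimizer nodes unconsumed, so treat the result only as candidates to check in TensorBoard.

```python
def guess_io_nodes(nodes, ignore_ops=("NoOp",)):
    """Return (candidate_inputs, candidate_outputs) for a list of graph nodes.

    nodes: list of (name, op_type, input_names) tuples, GraphDef style, so an
    input may look like "conv:0" (output index) or "^init" (control dependency).
    Heuristic: Placeholder ops are candidate inputs; nodes that no other node
    consumes (excluding op types in ignore_ops) are candidate outputs.
    """
    consumed = set()
    for _, _, node_inputs in nodes:
        for inp in node_inputs:
            # Drop the "^" control-dependency marker and the ":N" output index.
            consumed.add(inp.lstrip("^").split(":")[0])
    candidate_inputs = [name for name, op, _ in nodes if op == "Placeholder"]
    candidate_outputs = [name for name, op, _ in nodes
                         if name not in consumed and op not in ignore_ops]
    return candidate_inputs, candidate_outputs


# Hypothetical toy graph: input -> conv -> softmax, plus variable init nodes.
toy = [
    ("input_img", "Placeholder", []),
    ("conv/weights", "VariableV2", []),
    ("conv/weights/Initializer", "Const", []),
    ("conv/weights/Assign", "Assign", ["conv/weights", "conv/weights/Initializer"]),
    ("conv/weights/read", "Identity", ["conv/weights"]),
    ("conv/Conv2D", "Conv2D", ["input_img", "conv/weights/read"]),
    ("prob", "Softmax", ["conv/Conv2D:0"]),
    ("init", "NoOp", ["^conv/weights/Assign"]),
]
print(guess_io_nodes(toy))  # (['input_img'], ['prob'])
```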
3. Freeze the graph

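This step uses a ready-made TensorFlow script to freeze the parameters, i.e. to convert the variables into constants stored in the graph. The script is freeze_graph.py, in the tensorflow/python/tools folder of the TensorFlow source: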

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Converts checkpoint variables into Const ops in a standalone GraphDef file.

This script is designed to take a GraphDef proto, a SaverDef proto, and a set of
variable values stored in a checkpoint file, and output a GraphDef with all of
the variable ops converted into const ops containing the values of the
variables.

It's useful to do this when we need to load a single file in C++, especially in
environments like mobile or embedded where we may not have access to the
RestoreTensor ops and file loading calls that they rely on.

An example of command-line usage is:
bazel build tensorflow/python/tools:freeze_graph && \
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=some_graph_def.pb \
--input_checkpoint=model.ckpt-8361242 \
--output_graph=/tmp/frozen_graph.pb --output_node_names=softmax

You can also look at freeze_graph_test.py for an example of how to use it.

"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import re
import sys

from google.protobuf import text_format

from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.tools import saved_model_utils
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib


def _has_no_variables(sess):
  """Determines if the graph has any variables.

  Args:
    sess: TensorFlow Session.

  Returns:
    Bool.
  """
  for op in sess.graph.get_operations():
    if op.type.startswith("Variable") or op.type.endswith("VariableOp"):
      return False
  return True


def freeze_graph_with_def_protos(input_graph_def,
                                 input_saver_def,
                                 input_checkpoint,
                                 output_node_names,
                                 restore_op_name,
                                 filename_tensor_name,
                                 output_graph,
                                 clear_devices,
                                 initializer_nodes,
                                 variable_names_whitelist="",
                                 variable_names_blacklist="",
                                 input_meta_graph_def=None,
                                 input_saved_model_dir=None,
                                 saved_model_tags=None,
                                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph_def: A `GraphDef`.
    input_saver_def: A `SaverDef` (optional).
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated string of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph_def: A `MetaGraphDef` (optional),
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file
                           and variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format (optional).
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2)

  Returns:
    Location of the output_graph_def.
  """
  del restore_op_name, filename_tensor_name  # Unused by updated loading code.

  # 'input_checkpoint' may be a prefix if we're using Saver V2 format
  if (not input_saved_model_dir and
      not checkpoint_management.checkpoint_exists(input_checkpoint)):
    raise ValueError("Input checkpoint '" + input_checkpoint +
                     "' doesn't exist!")

  if not output_node_names:
    raise ValueError(
        "You need to supply the name of a node to --output_node_names.")

  # Remove all the explicit device specifications for this node. This helps to
  # make the graph more portable.
  if clear_devices:
    if input_meta_graph_def:
      for node in input_meta_graph_def.graph_def.node:
        node.device = ""
    elif input_graph_def:
      for node in input_graph_def.node:
        node.device = ""

  if input_graph_def:
    _ = importer.import_graph_def(input_graph_def, name="")
  with session.Session() as sess:
    if input_saver_def:
      saver = saver_lib.Saver(
          saver_def=input_saver_def, write_version=checkpoint_version)
      saver.restore(sess, input_checkpoint)
    elif input_meta_graph_def:
      restorer = saver_lib.import_meta_graph(
          input_meta_graph_def, clear_devices=True)
      restorer.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))
    elif input_saved_model_dir:
      if saved_model_tags is None:
        saved_model_tags = []
      loader.load(sess, saved_model_tags, input_saved_model_dir)
    else:
      var_list = {}
      reader = pywrap_tensorflow.NewCheckpointReader(input_checkpoint)
      var_to_shape_map = reader.get_variable_to_shape_map()

      # List of all partition variables. Because the condition is heuristic
      # based, the list could include false positives.
      all_parition_variable_names = [
          tensor.name.split(":")[0]
          for op in sess.graph.get_operations()
          for tensor in op.values()
          if re.search(r"/part_\d+/", tensor.name)
      ]
      has_partition_var = False

      for key in var_to_shape_map:
        try:
          tensor = sess.graph.get_tensor_by_name(key + ":0")
          if any(key in name for name in all_parition_variable_names):
            has_partition_var = True
        except KeyError:
          # This tensor doesn't exist in the graph (for example it's
          # 'global_step' or a similar housekeeping element) so skip it.
          continue
        var_list[key] = tensor

      try:
        saver = saver_lib.Saver(
            var_list=var_list, write_version=checkpoint_version)
      except TypeError as e:
        # `var_list` is required to be a map of variable names to Variable
        # tensors. Partition variables are Identity tensors that cannot be
        # handled by Saver.
        if has_partition_var:
          raise ValueError(
              "Models containing partition variables cannot be converted "
              "from checkpoint files. Please pass in a SavedModel using "
              "the flag --input_saved_model_dir.")
        # Models that have been frozen previously do not contain Variables.
        elif _has_no_variables(sess):
          raise ValueError(
              "No variables were found in this model. It is likely the model "
              "was frozen previously. You cannot freeze a graph twice.")
          return 0
        else:
          raise e

      saver.restore(sess, input_checkpoint)
      if initializer_nodes:
        sess.run(initializer_nodes.replace(" ", "").split(","))

    variable_names_whitelist = (
        variable_names_whitelist.replace(" ", "").split(",")
        if variable_names_whitelist else None)
    variable_names_blacklist = (
        variable_names_blacklist.replace(" ", "").split(",")
        if variable_names_blacklist else None)

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.replace(" ", "").split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)

  # Write GraphDef to file if output path has been given.
  if output_graph:
    with gfile.GFile(output_graph, "wb") as f:
      f.write(output_graph_def.SerializeToString())

  return output_graph_def


def _parse_input_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into GraphDef proto."""
  if not gfile.Exists(input_graph):
    raise IOError("Input graph file '" + input_graph + "' does not exist!")
  input_graph_def = graph_pb2.GraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_graph_def)
  return input_graph_def


def _parse_input_meta_graph_proto(input_graph, input_binary):
  """Parses input tensorflow graph into MetaGraphDef proto."""
  if not gfile.Exists(input_graph):
    raise IOError("Input meta graph file '" + input_graph + "' does not exist!")
  input_meta_graph_def = MetaGraphDef()
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_graph, mode) as f:
    if input_binary:
      input_meta_graph_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), input_meta_graph_def)
  print("Loaded meta graph file '" + input_graph)
  return input_meta_graph_def


def _parse_input_saver_proto(input_saver, input_binary):
  """Parses input tensorflow Saver into SaverDef proto."""
  if not gfile.Exists(input_saver):
    raise IOError("Input saver file '" + input_saver + "' does not exist!")
  mode = "rb" if input_binary else "r"
  with gfile.GFile(input_saver, mode) as f:
    saver_def = saver_pb2.SaverDef()
    if input_binary:
      saver_def.ParseFromString(f.read())
    else:
      text_format.Merge(f.read(), saver_def)
  return saver_def


def freeze_graph(input_graph,
                 input_saver,
                 input_binary,
                 input_checkpoint,
                 output_node_names,
                 restore_op_name,
                 filename_tensor_name,
                 output_graph,
                 clear_devices,
                 initializer_nodes,
                 variable_names_whitelist="",
                 variable_names_blacklist="",
                 input_meta_graph=None,
                 input_saved_model_dir=None,
                 saved_model_tags=tag_constants.SERVING,
                 checkpoint_version=saver_pb2.SaverDef.V2):
  """Converts all variables in a graph and checkpoint into constants.

  Args:
    input_graph: A `GraphDef` file to load.
    input_saver: A TensorFlow Saver file.
    input_binary: A Bool. True means input_graph is .pb, False indicates .pbtxt.
    input_checkpoint: The prefix of a V1 or V2 checkpoint, with V2 taking
      priority.  Typically the result of `Saver.save()` or that of
      `tf.train.latest_checkpoint()`, regardless of sharded/non-sharded or
      V1/V2.
    output_node_names: The name(s) of the output nodes, comma separated.
    restore_op_name: Unused.
    filename_tensor_name: Unused.
    output_graph: String where to write the frozen `GraphDef`.
    clear_devices: A Bool whether to remove device specifications.
    initializer_nodes: Comma separated list of initializer nodes to run before
                       freezing.
    variable_names_whitelist: The set of variable names to convert (optional, by
                              default, all variables are converted),
    variable_names_blacklist: The set of variable names to omit converting
                              to constants (optional).
    input_meta_graph: A `MetaGraphDef` file to load (optional).
    input_saved_model_dir: Path to the dir with TensorFlow 'SavedModel' file and
                           variables (optional).
    saved_model_tags: Group of comma separated tag(s) of the MetaGraphDef to
                      load, in string format.
    checkpoint_version: Tensorflow variable file format (saver_pb2.SaverDef.V1
                        or saver_pb2.SaverDef.V2).
  Returns:
    String that is the location of frozen GraphDef.
  """
  input_graph_def = None
  if input_saved_model_dir:
    input_graph_def = saved_model_utils.get_meta_graph_def(
        input_saved_model_dir, saved_model_tags).graph_def
  elif input_graph:
    input_graph_def = _parse_input_graph_proto(input_graph, input_binary)
  input_meta_graph_def = None
  if input_meta_graph:
    input_meta_graph_def = _parse_input_meta_graph_proto(
        input_meta_graph, input_binary)
  input_saver_def = None
  if input_saver:
    input_saver_def = _parse_input_saver_proto(input_saver, input_binary)
  return freeze_graph_with_def_protos(
      input_graph_def,
      input_saver_def,
      input_checkpoint,
      output_node_names,
      restore_op_name,
      filename_tensor_name,
      output_graph,
      clear_devices,
      initializer_nodes,
      variable_names_whitelist,
      variable_names_blacklist,
      input_meta_graph_def,
      input_saved_model_dir,
      saved_model_tags.replace(" ", "").split(","),
      checkpoint_version=checkpoint_version)


def main(unused_args, flags):
  if flags.checkpoint_version == 1:
    checkpoint_version = saver_pb2.SaverDef.V1
  elif flags.checkpoint_version == 2:
    checkpoint_version = saver_pb2.SaverDef.V2
  else:
    raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
                     flags.checkpoint_version)
  freeze_graph(flags.input_graph, flags.input_saver, flags.input_binary,
               flags.input_checkpoint, flags.output_node_names,
               flags.restore_op_name, flags.filename_tensor_name,
               flags.output_graph, flags.clear_devices, flags.initializer_nodes,
               flags.variable_names_whitelist, flags.variable_names_blacklist,
               flags.input_meta_graph, flags.input_saved_model_dir,
               flags.saved_model_tags, checkpoint_version)


def run_main():
  """Main function of freeze_graph."""
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--input_graph",
      type=str,
      default="",
      help="TensorFlow \'GraphDef\' file to load.")
  parser.add_argument(
      "--input_saver",
      type=str,
      default="",
      help="TensorFlow saver file to load.")
  parser.add_argument(
      "--input_checkpoint",
      type=str,
      default="",
      help="TensorFlow variables file to load.")
  parser.add_argument(
      "--checkpoint_version",
      type=int,
      default=2,
      help="Tensorflow variable file format")
  parser.add_argument(
      "--output_graph",
      type=str,
      default="",
      help="Output \'GraphDef\' file name.")
  parser.add_argument(
      "--input_binary",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="Whether the input files are in binary format.")
  parser.add_argument(
      "--output_node_names",
      type=str,
      default="",
      help="The name of the output nodes, comma separated.")
  parser.add_argument(
      "--restore_op_name",
      type=str,
      default="save/restore_all",
      help="""\
      The name of the master restore operator. Deprecated, unused by updated \
      loading code.
      """)
  parser.add_argument(
      "--filename_tensor_name",
      type=str,
      default="save/Const:0",
      help="""\
      The name of the tensor holding the save path. Deprecated, unused by \
      updated loading code.
      """)
  parser.add_argument(
      "--clear_devices",
      nargs="?",
      const=True,
      type="bool",
      default=True,
      help="Whether to remove device specifications.")
  parser.add_argument(
      "--initializer_nodes",
      type=str,
      default="",
      help="Comma separated list of initializer nodes to run before freezing.")
  parser.add_argument(
      "--variable_names_whitelist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to convert to constants. If specified, \
      only those variables will be converted to constants.\
      """)
  parser.add_argument(
      "--variable_names_blacklist",
      type=str,
      default="",
      help="""\
      Comma separated list of variables to skip converting to constants.\
      """)
  parser.add_argument(
      "--input_meta_graph",
      type=str,
      default="",
      help="TensorFlow \'MetaGraphDef\' file to load.")
  parser.add_argument(
      "--input_saved_model_dir",
      type=str,
      default="",
      help="Path to the dir with TensorFlow \'SavedModel\' file and variables.")
  parser.add_argument(
      "--saved_model_tags",
      type=str,
      default="serve",
      help="""\
      Group of tag(s) of the MetaGraphDef to load, in string format,\
      separated by \',\'. For tag-set contains multiple tags, all tags \
      must be passed in.\
      """)
  flags, unparsed = parser.parse_known_args()

  my_main = lambda unused_args: main(unused_args, flags)
  app.run(main=my_main, argv=[sys.argv[0]] + unparsed)


if __name__ == "__main__":
  run_main()

The script takes many parameters:

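  1. input_graph: the graph file of the model to load, as a binary .pb or a text .pbtxt GraphDef. A .meta file goes to input_meta_graph instead.
  2. input_saver: the corresponding TensorFlow Saver file. In the many examples I looked at, this can almost always be omitted.
  3. input_checkpoint: the location of your checkpoint, given as its prefix (e.g. name.ckpt, not name.ckpt.index).
  4. checkpoint_version: the checkpoint format version; defaults to 2.
  5. output_graph: the file name of the output graph.
  6. input_binary: whether the input files are binary. It defaults to false, so pass --input_binary=true for a binary .pb.
  7. output_node_names: the names of the output nodes, comma separated if there are several.
  8. restore_op_name: deprecated and unused by the current loading code; defaults to save/restore_all.
  9. filename_tensor_name: also deprecated and unused.
  10. clear_devices: defaults to true. It removes the device placements specified during training, such as which GPU to use.
  11. initializer_nodes: initializer nodes to run before freezing, comma separated if there are several; rarely used.
  12. variable_names_whitelist: the variables to freeze, comma separated. If not given, all variables are frozen.
  13. variable_names_blacklist: the variables not to freeze, comma separated.
  14. input_meta_graph: path to the .meta (MetaGraphDef) file.
  15. input_saved_model_dir: directory containing the SavedModel file and its variables.
  16. saved_model_tags: the tag(s) of the MetaGraphDef to load, as a comma-separated string (default: serve). If the tag set contains multiple tags, all of them must be passed.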
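One detail is worth checking for the boolean options (input_binary, clear_devices). The script's run_main registers its own "bool" type as `lambda v: v.lower() == "true"`, so only the literal string true (in any case) counts as true. The sketch below reproduces just those two argument definitions with argparse to show the effect. `make_parser` is a hypothetical name, not part of the script.

```python
import argparse


def make_parser():
    """Re-create freeze_graph.py's two boolean options (see run_main above)."""
    parser = argparse.ArgumentParser()
    # Same custom type as the script: only the string "true" (any case) is True.
    parser.register("type", "bool", lambda v: v.lower() == "true")
    parser.add_argument("--input_binary", nargs="?", const=True,
                        type="bool", default=False)
    parser.add_argument("--clear_devices", nargs="?", const=True,
                        type="bool", default=True)
    return parser


parser = make_parser()
print(parser.parse_args(["--input_binary"]).input_binary)       # True (bare flag uses const)
print(parser.parse_args(["--input_binary=true"]).input_binary)  # True
print(parser.parse_args(["--input_binary=1"]).input_binary)     # False: "1" is not "true"
print(parser.parse_args([]).input_binary)                        # False (default)
```

So `--input_binary=1` or `--input_binary=yes` silently leaves a binary .pb being read as a text GraphDef. That typically ends in a parse error.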

Despite all these parameters, in practice the essential form of the command is:

python freeze_graph.py \
    --input_graph=path/to/your_graph.pb \
    --input_binary=true \
    --input_checkpoint=path/to/your_model.ckpt \
    --output_graph=path/to/your_frozen_graph.pb \
    --output_node_names=name_of_your_output_node

Then set the remaining options according to your needs.
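When writing the command, pass each option as `--flag=value` (or `--flag value`) with no spaces around `=`. The shell splits `--input_graph = model.pb` into three tokens, so argparse takes `=` as the value and leaves the real path as an unrecognized argument. Similarly, a shortened flag such as `--output_node_name` is accepted only because argparse matches unambiguous prefixes by default, so spelling out `--output_node_names` is safer. The sketch below demonstrates both behaviors. It uses an illustrative two-option parser that mimics the script's definitions; it is not the script itself.

```python
import argparse

# Two of freeze_graph.py's string options, defined the same way as in run_main.
parser = argparse.ArgumentParser()
parser.add_argument("--input_graph", type=str, default="")
parser.add_argument("--output_node_names", type=str, default="")

# What the shell passes for: --input_graph = model.pb
bad, bad_rest = parser.parse_known_args(["--input_graph", "=", "model.pb"])
print(repr(bad.input_graph), bad_rest)    # '=' ['model.pb']

# What the shell passes for: --input_graph=model.pb
good, good_rest = parser.parse_known_args(["--input_graph=model.pb"])
print(repr(good.input_graph), good_rest)  # 'model.pb' []

# Prefix matching lets the abbreviated flag through.
abbrev, _ = parser.parse_known_args(["--output_node_name=prob"])
print(abbrev.output_node_names)           # prob
```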

4. Optimize the pb

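Some nodes that are needed during training are redundant at inference time. The model should therefore be optimized, which mainly means removing those extra nodes.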

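import os

from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib


def opt_freezed_pb(tmp_dir,
                   input_node_names,
                   output_node_names,
                   input_graph_name,
                   output_graph_name):
    """Remove inference-irrelevant nodes from a frozen pb in tmp_dir.

    input_node_names / output_node_names: comma-separated node names.
    """
    input_graph_path = os.path.join(tmp_dir, input_graph_name)

    input_graph_def = graph_pb2.GraphDef()
    with gfile.Open(input_graph_path, "rb") as f:
        input_graph_def.ParseFromString(f.read())  # the frozen pb is binary

    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
        input_graph_def,
        input_node_names.split(","),
        output_node_names.split(","),
        dtypes.float32.as_datatype_enum)  # dtype of the input placeholder(s)

    output_graph_filename = os.path.join(tmp_dir, output_graph_name)
    # The serialized protobuf is bytes: open in binary mode and close the file.
    with gfile.GFile(output_graph_filename, "wb") as f:
        f.write(output_graph_def.SerializeToString())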

You can compare the graphs before and after optimization in TensorBoard to see exactly what changed.
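The idea of "removing redundant nodes" can be sketched in a few lines. Start from the output nodes, walk the input edges backwards, and keep only the nodes you reach, so loss, label and training-op nodes drop out. This is a simplified, stdlib-only illustration of the strip-unused idea. It works on a hypothetical dict of node name -> input names, and `strip_unused` here is an illustrative function, not TensorFlow's. The real optimize_for_inference also does more. For example, it removes training-only Identity/CheckNumerics nodes and folds batch-normalization ops into the preceding convolution.

```python
from collections import deque


def strip_unused(graph, output_names):
    """Return the names of nodes the outputs depend on, in original order.

    graph: dict mapping node name -> list of input names (GraphDef style, so
    "conv:0" and "^init" are accepted). Every referenced name must be a key.
    """
    keep = set()
    queue = deque(output_names)
    while queue:
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        for inp in graph[name]:
            # Normalize "^x" and "x:N" to the plain node name "x".
            queue.append(inp.lstrip("^").split(":")[0])
    return [name for name in graph if name in keep]


toy = {
    "input_img": [],
    "weights": [],
    "conv": ["input_img", "weights"],
    "prob": ["conv:0"],
    "labels": [],
    "loss": ["prob", "labels"],
    "train_op": ["^loss"],
}
print(strip_unused(toy, ["prob"]))  # ['input_img', 'weights', 'conv', 'prob']
```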

5. Check the optimized pb model

Finally, run the optimized pb model in Python and check that its inference results match those of the original ckpt model.
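When comparing the outputs, use a small tolerance rather than exact equality. Optimizations such as batch-norm folding change the order of floating-point operations, so results may differ in the last few digits. Below is a minimal stdlib sketch. It assumes both outputs have been converted to nested Python lists (e.g. with NumPy's `.tolist()`). `max_abs_diff`, `outputs_match` and the default `atol=1e-5` are illustrative choices; with NumPy, `np.allclose` does a similar job.

```python
def max_abs_diff(a, b):
    """Largest element-wise |a - b| between two equally shaped nested lists."""
    if isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
        if not (isinstance(a, (list, tuple)) and isinstance(b, (list, tuple))
                and len(a) == len(b)):
            raise ValueError("outputs have different shapes")
        return max((max_abs_diff(x, y) for x, y in zip(a, b)), default=0.0)
    return abs(float(a) - float(b))


def outputs_match(ckpt_out, pb_out, atol=1e-5):
    """True if the pb output reproduces the ckpt output within atol."""
    return atol >= max_abs_diff(ckpt_out, pb_out)


# Hypothetical softmax outputs for one image from the two models.
print(outputs_match([[0.1, 0.9]], [[0.1000001, 0.8999999]]))  # True
print(outputs_match([[0.1, 0.9]], [[0.2, 0.8]]))              # False
```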

6. And finally, of course, test it with the TensorFlow C++ interface!

Reposted from blog.csdn.net/Felaim/article/details/101554335