TensorFlow 模型文件的使用以及格式转换——OpenCV DNN 可调用格式(一)

前几天在研究如何把模型集成到 OpenCV 中,发现之前的方式比较麻烦,于是查找资料研究了一下,发现下面这种方式非常不错。

展开

TensorFlow 模型的 graph 结构可以保存为 .pb 文件、.pbtxt 文件或 .meta 文件,其中只有 .pbtxt 文件是文本格式、人类可读的,另外两种默认都是二进制格式。

网上大牛们训练好的网络,通常会将模型保存为一个统一的 .pb 文件(即冻结后的模型)。这个文件中不仅保存着模型网络的结构和变量名,还保存了所有变量的值。如果我们想利用别人训练好的模型对自己的数据进行测试,往往要对这个模型做一些修改,这时经常需要知道原有模型里的一些张量名称。但是 .pb 文件和 .meta 文件都是不可读的,所以有必要对这两种文件进行格式转换。

①.meta文件

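这种情况下,通常还需要其他几个 checkpoint 文件(checkpoint、model.ckpt.index、model.ckpt.data-* 等)。可以使用 TensorFlow 源码目录下的 tensorflow/python/tools/inspect_checkpoint.py 脚本(作者本机路径为 /home/zhaixingzhe/tensorflow/tensorflow/python/tools/inspect_checkpoint.py),打印输出 checkpoint 中保存的所有张量(变量)的名称和形状;加上 --tensor_name 或 --all_tensors 参数还可以打印张量的具体数值。用法示例:python inspect_checkpoint.py --file_name=./model.ckpt。对于 V2 格式的 checkpoint,--file_name 应填写各文件共同的前缀,而不是带 .index 或 .data 后缀的完整文件名。下面是 inspect_checkpoint.py 的全部代码: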

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple script for inspect checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import argparse
import sys
 
import numpy as np
 
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
 
# Parsed command-line flags (an argparse namespace). Stays None on import;
# assigned in the `__main__` block from `parse_known_args()` and read by `main`.
FLAGS = None
 
 
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
  """Prints tensors in a checkpoint file.

  If `all_tensors` is true, prints the name and value of every tensor in the
  checkpoint. Otherwise, if no `tensor_name` is provided, prints the tensor
  names and shapes in the checkpoint file. If `tensor_name` is provided,
  prints the content of that tensor.

  Exceptions raised while reading are caught and printed rather than
  propagated, followed by a hint when the message suggests a SNAPPY-compressed
  file or a V2 checkpoint addressed by a full filename instead of its prefix.

  Args:
    file_name: Name of the checkpoint file (for V2 checkpoints, the filename
      prefix shared by all files of the checkpoint).
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Boolean indicating whether to print all tensors.
  """
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    # Best-effort CLI tool: report the error plus a hint instead of crashing.
    error_message = str(e)
    print(error_message)
    if "corrupted compressed block contents" in error_message:
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
    # Use a distinct name for the generator variable: under Python 2 a list
    # comprehension leaks its loop variable and would rebind the caught `e`.
    if ("Data loss" in error_message and
        any(ext in file_name for ext in (".index", ".meta", ".data"))):
      # Drop the last '.'-separated component to recover the V2 prefix.
      proposed_file = ".".join(file_name.split(".")[0:-1])
      # Suggested command must be runnable: correct tool name, no spaces
      # around '=' (argparse would otherwise take '=' as the value).
      v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect_checkpoint --file_name={}"""
      print(v2_file_error_template.format(proposed_file))
 
 
def parse_numpy_printoption(kv_str):
  """Sets a single numpy printoption from a string of the form 'x=y'.

  See documentation on numpy.set_printoptions() for details about what values
  x and y can take. x can be any option listed there other than 'formatter'.
  Used as an argparse `type=` callable: the option is applied to numpy as a
  side effect and the function returns None.

  Args:
    kv_str: A string of the form 'x=y', such as 'threshold=100000'

  Raises:
    argparse.ArgumentTypeError: If the string couldn't be used to set any
        numpy printoption.
  """
  k_v_str = kv_str.split("=", 1)
  if len(k_v_str) != 2 or not k_v_str[0]:
    raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
  k, v_str = k_v_str
  printoptions = np.get_printoptions()
  if k not in printoptions:
    raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
  # The type of the option's current value decides how v_str is converted.
  v_type = type(printoptions[k])
  if v_type is type(None):
    raise argparse.ArgumentTypeError(
        "Setting '%s' from the command line is not supported." % k)
  try:
    # bool("False") is True, so booleans go through a real boolean parser.
    v = (v_type(v_str) if v_type is not bool
         else flags.BooleanParser().parse(v_str))
  except ValueError as e:
    # Python 3 exceptions have no `.message`; str(e) works on 2 and 3 and
    # keeps the documented ArgumentTypeError contract.
    raise argparse.ArgumentTypeError(str(e))
  np.set_printoptions(**{k: v})
 
 
def main(unused_argv):
  """Runs the checkpoint inspector using the parsed command-line flags.

  When --file_name is missing, prints a usage line and exits with status 1.

  Args:
    unused_argv: Arguments forwarded by `app.run`; ignored.
  """
  checkpoint_name = FLAGS.file_name
  if checkpoint_name:
    print_tensors_in_checkpoint_file(checkpoint_name, FLAGS.tensor_name,
                                     FLAGS.all_tensors)
    return
  print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
        "[--tensor_name=tensor_to_print]")
  sys.exit(1)
 
 
if __name__ == "__main__":
  arg_parser = argparse.ArgumentParser()

  def _parse_bool(value):
    # Only the (case-insensitive) string "true" counts as True.
    return value.lower() == "true"

  # Make the converter available under the type name "bool".
  arg_parser.register("type", "bool", _parse_bool)
  arg_parser.add_argument(
      "--file_name",
      type=str,
      default="",
      help="Checkpoint filename. "
      "Note, if using Checkpoint V2 format, file_name is the "
      "shared prefix between all files in the checkpoint.")
  arg_parser.add_argument(
      "--tensor_name",
      type=str,
      default="",
      help="Name of the tensor to inspect")
  # A bare `--all_tensors` with no value yields const=True.
  arg_parser.add_argument(
      "--all_tensors",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="If True, print the values of all the tensors.")
  # Each k=v item is applied to numpy by parse_numpy_printoption as parsed.
  arg_parser.add_argument(
      "--printoptions",
      nargs="*",
      type=parse_numpy_printoption,
      help="Argument for numpy.set_printoptions(), in the form 'k=v'.")
  # Unrecognized arguments are handed to app.run after the program name.
  FLAGS, remaining_argv = arg_parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + remaining_argv)

②.pb文件

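下面的代码定义了两个函数,分别实现 .pb 文件到 .pbtxt 文件、以及 .pbtxt 文件到 .pb 文件的转换。注意代码使用的是 TensorFlow 1.x 的 API(如 tf.GraphDef);在 TensorFlow 2.x 中需改用 tf.compat.v1 下的对应接口。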

 
import tensorflow as tf
from tensorflow.python.platform import gfile
from google.protobuf import text_format
 
def convert_pb_to_pbtxt(filename, logdir='./', name='protobuf.pbtxt'):
  """Converts a binary GraphDef (.pb) file into a text-format (.pbtxt) file.

  Args:
    filename: Path to a file containing a serialized (binary) `tf.GraphDef`.
    logdir: Directory the text file is written into. Defaults to './'.
    name: Output filename inside `logdir`. Defaults to 'protobuf.pbtxt'.

  Returns:
    The path of the written file, as returned by `tf.train.write_graph`.
  """
  with gfile.GFile(filename, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
  # The GraphDef is written out directly. Importing it into the default graph
  # (tf.import_graph_def) is unnecessary for a pure format conversion, pollutes
  # the default graph, and can fail for ops not registered in this process.
  return tf.train.write_graph(graph_def, logdir, name, as_text=True)
 
def convert_pbtxt_to_pb(filename, logdir='./', name='protobuf.pb'):
  """Converts a text-format GraphDef (.pbtxt) file into a binary (.pb) file.

  Args:
    filename: The name of a file containing a GraphDef pbtxt (text-formatted
      `tf.GraphDef` protocol buffer data).
    logdir: Directory the binary file is written into. Defaults to './'.
    name: Output filename inside `logdir`. Defaults to 'protobuf.pb'.

  Returns:
    A `tf.GraphDef` proto representing the data in the given pbtxt file.
  """
  with gfile.GFile(filename, 'r') as f:
    file_content = f.read()

  graph_def = tf.GraphDef()
  # Merges the human-readable string in `file_content` into `graph_def`.
  text_format.Merge(file_content, graph_def)
  tf.train.write_graph(graph_def, logdir, name, as_text=False)
  # Return the parsed proto, as the docstring has always promised.
  return graph_def
发布了86 篇原创文章 · 获赞 267 · 访问量 177万+

猜你喜欢

转载自blog.csdn.net/javastart/article/details/104741884