Uso Tensorflow conversión de formato de archivo de modelo y -OpenCV DNN pueden llamar formato (a)

Hace unos días para ver la mayor OpenCV integrada, que se encuentra en frente de manera muy complicado, y ahora en busca de información y la investigación encontró que este método es muy bueno.

expansión

Tensorflow gráfico modelo de estructura puede ser guardado como un archivo o archivo .pb .pbtxt o archivo .meta, en el que el archivo es legible solamente .pbtxt.

Daniel fueron entrenados en línea de la red, guardar el modelo como un archivo .pb unificado, que posee más que en la estructura y el modelo de los nombres de las variables de la red, sino también guardar los valores de todas las variables, si quieren aprovecharse de otras personas capacitadas modelo para poner a prueba sus propios datos, tienden a hacer algunos cambios en este modelo, a continuación, a menudo necesitamos saber el nombre del modelo tensor interior original, pero los archivos .meta .PB y documentos no se pueden leer, todo lo necesario éstos conversión de formato de archivo de dos.

archivo ①.meta

En este caso, por lo general requiere varios otros archivo de controles, puesto de control, model.cpkt.index, model.cpkt.data similares, se puede utilizar para instalar tensofrflow / home / zhaixingzhe / tensorflow directorio / tensorflow / Python / herramientas / inspect_checkpoint .py nombre de archivo de salida de impresión del modelo para todos tensor (tensor) y la operación (OP), la siguiente es todo el inspect_checkpoint.py código:

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple script for inspect checkpoint files."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
 
import argparse
import sys
 
import numpy as np
 
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
 
FLAGS = None
 
 
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
  """Prints tensors in a checkpoint file.
  If no `tensor_name` is provided, prints the tensor names and shapes
  in the checkpoint file.
  If `tensor_name` is provided, prints the content of the tensor.
  Args:
    file_name: Name of the checkpoint file.
    tensor_name: Name of the tensor in the checkpoint file to print.
    all_tensors: Boolean indicating whether to print all tensors.
  """
  try:
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
      var_to_shape_map = reader.get_variable_to_shape_map()
      for key in sorted(var_to_shape_map):
        print("tensor_name: ", key)
        print(reader.get_tensor(key))
    elif not tensor_name:
      print(reader.debug_string().decode("utf-8"))
    else:
      print("tensor_name: ", tensor_name)
      print(reader.get_tensor(tensor_name))
  except Exception as e:  # pylint: disable=broad-except
    print(str(e))
    if "corrupted compressed block contents" in str(e):
      print("It's likely that your checkpoint file has been compressed "
            "with SNAPPY.")
    if ("Data loss" in str(e) and
        (any([e in file_name for e in [".index", ".meta", ".data"]]))):
      proposed_file = ".".join(file_name.split(".")[0:-1])
      v2_file_error_template = """
It's likely that this is a V2 checkpoint and you need to provide the filename
*prefix*.  Try removing the '.' and extension.  Try:
inspect checkpoint --file_name = {}"""
      print(v2_file_error_template.format(proposed_file))
 
 
def parse_numpy_printoption(kv_str):
  """Sets a single numpy printoption from a string of the form 'x=y'.
  See documentation on numpy.set_printoptions() for details about what values
  x and y can take. x can be any option listed there other than 'formatter'.
  Args:
    kv_str: A string of the form 'x=y', such as 'threshold=100000'
  Raises:
    argparse.ArgumentTypeError: If the string couldn't be used to set any
        nump printoption.
  """
  k_v_str = kv_str.split("=", 1)
  if len(k_v_str) != 2 or not k_v_str[0]:
    raise argparse.ArgumentTypeError("'%s' is not in the form k=v." % kv_str)
  k, v_str = k_v_str
  printoptions = np.get_printoptions()
  if k not in printoptions:
    raise argparse.ArgumentTypeError("'%s' is not a valid printoption." % k)
  v_type = type(printoptions[k])
  if v_type is type(None):
    raise argparse.ArgumentTypeError(
        "Setting '%s' from the command line is not supported." % k)
  try:
    v = (v_type(v_str) if v_type is not bool
         else flags.BooleanParser().parse(v_str))
  except ValueError as e:
    raise argparse.ArgumentTypeError(e.message)
  np.set_printoptions(**{k: v})
 
 
def main(unused_argv):
  if not FLAGS.file_name:
    print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
          "[--tensor_name=tensor_to_print]")
    sys.exit(1)
  else:
    print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
                                     FLAGS.all_tensors)
 
 
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--file_name", type=str, default="", help="Checkpoint filename. "
                    "Note, if using Checkpoint V2 format, file_name is the "
                    "shared prefix between all files in the checkpoint.")
  parser.add_argument(
      "--tensor_name",
      type=str,
      default="",
      help="Name of the tensor to inspect")
  parser.add_argument(
      "--all_tensors",
      nargs="?",
      const=True,
      type="bool",
      default=False,
      help="If True, print the values of all the tensors.")
  parser.add_argument(
      "--printoptions",
      nargs="*",
      type=parse_numpy_printoption,
      help="Argument for numpy.set_printoptions(), in the form 'k=v'.")
  FLAGS, unparsed = parser.parse_known_args()
  app.run(main=main, argv=[sys.argv[0]] + unparsed)

archivo ②.pb

El siguiente código define dos funciones, archivos y de conversión entre los archivos .pbtxt .PB pueden ser implementados.

 
import tensorflow as tf
from tensorflow.python.platform import gfile
from google.protobuf import text_format
 
def convert_pb_to_pbtxt(filename):
  with gfile.FastGFile(filename,'rb') as f:
    graph_def = tf.GraphDef()
 
    graph_def.ParseFromString(f.read())
 
    tf.import_graph_def(graph_def, name='')
 
    tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True)
  return
 
def convert_pbtxt_to_pb(filename):
  """Returns a `tf.GraphDef` proto representing the data in the given pbtxt file.
  Args:
    filename: The name of a file containing a GraphDef pbtxt (text-formatted
      `tf.GraphDef` protocol buffer data).
  """
  with tf.gfile.FastGFile(filename, 'r') as f:
    graph_def = tf.GraphDef()
 
    file_content = f.read()
 
    # Merges the human-readable string in `file_content` into `graph_def`.
    text_format.Merge(file_content, graph_def)
    tf.train.write_graph( graph_def , './' , 'protobuf.pb' , as_text = False )
  return

 

Publicado 86 artículos originales · ganado elogios 267 · Vistas 1,77 millones +

Supongo que te gusta

Origin blog.csdn.net/javastart/article/details/104741884
Recomendado
Clasificación