Guardado del modelo de Tensorflow, modificación de nodos y optimización de gráficos de servicio

Guardado del modelo de Tensorflow, modificación de nodos y optimización de gráficos de servicio

Prólogo (no relacionado con el texto, puede ignorarse)

En un futuro cercano, tengo la intención de resumir algunos conocimientos básicos de Tensorflow para facilitar la referencia. La motivación para escribir este artículo es considerar un pequeño problema: a menudo usamos tf.datauna serie de Los gráficos suelen ser nodos iteradores (como la llamada tf.data.make_one_shot_iteratory el get_next()método ), pero cuando estaba en servicio, estaba pensando en cómo lidiar con los nodos de entrada y cómo tf.placeholderagregar nuevos al gráfico de servicio.

Un método es reescribir el gráfico de servicio, actualizar el nodo de entrada tf.placeholdery luego ingresarlo en el modelo para generar un nuevo gráfico; pero espero que haya un método más conciso, como si es posible reemplazar directamente el nodo de entrada del iterador. tf.placeholderwith Incluso si no sé cómo está escrito el código del modelo, puedo construir un gráfico de servicio Bajo la guía de esta pregunta, tengo un poco de comprensión profunda de conceptos como guardar y cargar modelos TF y Graph /Metagráfico.

descripción general

Este artículo presenta el método para guardar parte del modelo de Tensorflow, incluido principalmente checkpointel formato , frozen_graphel formato ( SavedModelel formato se omite temporalmente), a través del ejemplo de código para comprender el método para guardar el modelo, la optimización del gráfico de servicio y la modificación. y actualización de los nodos en el gráfico de Servicio.

código de dirección

El código de este artículo se probó con éxito en el Python 3.5.2entorno Tensorflow 1.15.0.

Todo el código de este artículo se puede descargar desde https://github.com/axzml/BlogShare/tree/master/Tensorflow/GraphDef .

Amplia publicidad

Puede buscar "Jenny's Algorithm Road" o "world4458" en WeChat, seguir mi cuenta pública de WeChat y obtener las últimas actualizaciones de artículos técnicos originales a tiempo.

Además, puede echar un vistazo a la columna de Zhihu PoorMemory-Machine Learning , y también se publicarán artículos futuros en la columna de Zhihu.

formato de punto de control

código de entrenamiento y guardar ckpt

Escribí un código de entrenamiento simple ( train.py) de la siguiente manera, completo con cinco órganos internos, que define tres funciones principales:

  • data_generator(): Genere datos falsos para participar en el entrenamiento del modelo.
  • model(): define una red neuronal simple
  • train(): Defina el código de entrenamiento, llame para guardar el modelo tf.train.Saver()en forma de punto de control
# _*_ coding:utf-8 _*_
## train.py
import tensorflow as tf
import os
import numpy as np
from os.path import join, exists

batch_size = 2
steps = 10
epochs = 1
emb_dim = 4
sample_num = epochs * steps * batch_size

checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)

def data_generator():
	"""产生 Fake 训练数据"""
    dataset = tf.data.Dataset.from_tensor_slices((np.random.randn(sample_num, emb_dim),\
                        np.random.randn(sample_num)))
    dataset = dataset.repeat(epochs).batch(batch_size)
    iterator = tf.data.make_one_shot_iterator(dataset)
    feature, label = iterator.get_next()
    return feature, label

def model(feature, params=[10, 5, 1]):
	"""定义模型, 3层DNN"""
    fc1 = tf.layers.dense(feature, params[0], activation=tf.nn.relu, name='fc1')
    fc2 = tf.layers.dense(fc1, params[1], activation=tf.nn.relu, name='fc2')
    fc3 = tf.layers.dense(fc2, params[2], activation=tf.nn.sigmoid, name='fc3')
    out = tf.identity(fc3, name='output')
    return out

def train():
    feature, label = data_generator()
    output = model(feature)
    loss = tf.reduce_mean(tf.square(output - label))
    train_op = tf.train.AdamOptimizer(learning_rate=0.1, name='Adam').minimize(loss)
    saver = tf.train.Saver()

    if exists(checkpoint_dir):
        os.system('rm -rf {}'.format(checkpoint_dir))

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        try:
            local_step = 0
            save_freq = 2
            while True:
                local_step += 1
                _, loss_val = sess.run([train_op, loss])
                if local_step % save_freq == 0:
                    saver.save(sess, saver_dir)
                print('loss: {:.4f}'.format(loss_val))
        except tf.errors.OutOfRangeError:
            print("train end!")


if __name__ == '__main__':
    train()

La python train.pyejecución generará checkpoint_dirun directorio bajo el directorio actual, que consta de lo siguiente:

checkpoint_dir/
|-- 0.data-00000-of-00001  ## 记录了网络参数值 
|-- 0.index  ## 记录了网络参数名
|-- 0.meta   ## 保存 MetaGraphDef, 该文件以 pb 格式记录了网络结构
`-- checkpoint  ## 该文件记录了最新的 ckpt

Cargue ckpt y verifique la estructura del gráfico

checkpointEl modelo en el formato debe cargarse en el marco de Tensorflow. Por ejemplo, escribe eval.pyen inferencia, el código es el siguiente:

#_*_ coding:utf-8 _*_
## eval.py
import tensorflow as tf
import os
from os.path import join, exists
import numpy as np

emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)

np.random.seed(123)
test_data = np.random.randn(4, emb_dim) ## 生成测试数据

def eval_graph():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(meta_file)
        saver.restore(sess, model_file)
        output = sess.run(['output:0'], feed_dict={
    
    
            'IteratorGetNext:0': test_data
        })
        print('eval_graph:\n{}'.format(output))

if __name__ == '__main__':
    eval_graph()

En el código anterior, observe que los nombres de los nodos de entrada y salida son outputy IteratorGetNextPara el nodo de salida, train.pyya que model()se usa en la función de

out = tf.identity(fc3, name='output')

Cambie el nombre del nodo de salida a output, de modo que el nombre del nodo de salida sea muy fácil de determinar. Pero el nombre del nodo de entrada no es muy fácil de determinar, porque la tf.dataAPI se usa para pasar los datos durante el entrenamiento y el nodo de entrada no tiene un nombre explícito Sin embargo, debido a que Al guardar el modelo, la estructura de la red se almacenó en 0.metael archivo , por lo que los nodos de entrada de la red se pueden ver analizando el archivo, el método específico es el siguiente:

#_*_ coding:utf-8 _*_
## check_graph.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from google.protobuf import text_format

import os
from os.path import join, exists
import numpy as np

checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)

def read_pb_meta(meta_file):
	"""读取 pb 格式的 meta 文件"""
    meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
    return meta_graph_def

def read_txt_meta(txt_meta_file):
	"""读取文本格式的 meta 文件"""
    meta_graph = MetaGraphDef()
    with open(txt_meta_file, 'rb') as f:
        text_format.Merge(f.read(), meta_graph)
    return meta_graph

def read_pb_graph(graph_file):
	"""读取 pb 格式的 graph_def 文件"""
    try:
        with tf.gfile.GFile(graph_file, 'rb') as pb:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(pb.read())
    except IOError as e:
        raise Exception("Parse '{}' Failed!".format(graph_file))
    return graph_def


def check_graph_def(graph_def):
	"""检查 graph_def 中的各节点"""
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            name=""
        )
        print('===> {}'.format(type(graph)))
        for op in graph.get_operations():
            print(op.name, op.values())  ## 打印网络结构

def check_graph(graph_file):
	"""检查 pb 格式的 graph_def 文件中的各节点"""
    graph_def = read_pb_graph(graph_file)
    check_graph_def(graph_def)
    

if __name__ == '__main__':
    check_graph_def(read_pb_meta(meta_file).graph_def)

El resultado de salida se muestra en la siguiente figura. Se puede encontrar que el nodo fc1/kernelmás cercano es IteratorGetNext, por lo que básicamente se puede confirmar que el nombre del nodo de entrada es ese.

modificación de nodo

Ahora volviendo a la pregunta mencionada en el "Prefacio", si quiero usar el tf.placeholdernodo como el nodo de entrada del Gráfico en lugar de usarlo, IteratorGetNext¿cómo debo lograrlo? Por un lado, puedo reescribir el Tensorflow Grafique y utilícelo tf.placeholdercomo entrada; por otro lado, de hecho, puede considerar reemplazar IteratorGetNetel nodo con un nodo personalizado. Este paso se refiere a la publicación del blog Cómo modificar el gráfico TF después de construirlo . El método específico es el siguiente , el código está infer.pyen :

#_*_ coding:utf-8 _*_
## infer.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
import os
from os.path import join, exists
import numpy as np

emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)

np.random.seed(123)
test_data = np.random.randn(4, emb_dim)

def read_pb_meta(meta_file):
    meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
    return meta_graph_def

def update_node(graph, src_node_name, tar_node):
    """
    @params:
        graph : tensorflow Graph object
        src_node_name : source node name to be modified
        tar_node : target node
    """
    input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
    for op in input.consumers():
        idx_list = []
        for idx, item in enumerate(op.inputs):
            if src_node_name in item.name:
                idx_list.append(idx)
        for idx in idx_list:
            op._update_input(idx, tar_node)

def modify_graph():
    meta_graph_def = read_pb_meta(meta_file)
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(meta_graph_def.graph_def, name="")
        input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
        update_node(graph, 'IteratorGetNext', input_ph)

    with tf.Session(graph=graph) as sess:
        saver = tf.train.import_meta_graph(meta_file)
        saver.restore(sess, model_file)
        output = sess.run(['output:0'], feed_dict={
    
    
            'input:0': test_data
        })
        print('modify_graph:\n{}'.format(output))


if __name__ == '__main__':
    modify_graph()

Este archivo define la función update_nodepara reemplazar los nodos en el gráfico, la función es la siguiente:

def update_node(graph, src_node_name, tar_node):
    """
    @params:
        graph : tensorflow Graph object
        src_node_name : source node name to be modified
        tar_node : target node
    """
    input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
    for op in input.consumers():
        idx_list = []
        for idx, item in enumerate(op.inputs):
            if src_node_name in item.name:
                idx_list.append(idx)
        for idx in idx_list:
            op._update_input(idx, tar_node)

Entre ellos, src_node_namerepresenta el nombre del nodo que se va a reemplazar, por ejemplo, si desea reemplazarlo IteratorGetNext. Busque el nodo correspondiente grapha input, y luego llame para input.consumers()encontrar el nodo que usa el nodo, y luego reemplace el nodo opa través del entrada opactualizada ( ).Debido al método de reemplazo, op.inputsse requiereop._update_input un índice idx, por lo que se usa idx_listpara registrar el índice del nodo que se reemplazará.

formato de gráfico congelado

checkpointEl formato descrito anteriormente frozen_graphguarda la estructura y los parámetros de la red por separado, mientras que el formato escribirá los parámetros de la red en GraphDef en forma de nodos Const y los guardará en un archivo protobuf unificado, porque protobuf es una serialización de datos multiplataforma y multilenguaje. protocol , por lo que también puede usar C++/Java/Python para cargar el modelo.

Un ejemplo simple de convertir ckpt a frozen_graph se escribe a continuación frozen_graph.py, el código es el siguiente:

#_*_ coding:utf-8 _*_
## frozen_graph.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import dtypes
from tensorflow.python.tools import optimize_for_inference_lib
import os
from os.path import join, exists
import numpy as np

emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)

np.random.seed(123)
test_data = np.random.randn(4, emb_dim)

def read_pb_meta(meta_file):
    meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
    return meta_graph_def

def update_node(graph, src_node_name, tar_node):
    """
    @params:
        graph : tensorflow Graph object
        src_node_name : source node name to be modified
        tar_node : target node
    """
    input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
    for op in input.consumers():
        idx_list = []
        for idx, item in enumerate(op.inputs):
            if src_node_name in item.name:
                idx_list.append(idx)
        for idx in idx_list:
            op._update_input(idx, tar_node)

def check_graph_def(graph_def):
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            name=""
        )
        print('===> {}'.format(type(graph)))
        for op in graph.get_operations():
            print(op.name, op.values())  ## 打印网络结构

def write_frozen_graph():
    meta_graph_def = read_pb_meta(meta_file)
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(meta_graph_def.graph_def, name="")
        input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
        update_node(graph, 'IteratorGetNext', input_ph)

    with tf.Session(graph=graph) as sess:
        saver = tf.train.import_meta_graph(meta_file)
        saver.restore(sess, model_file)

        input_node_names = ['input']
        ##placeholder_type_enum = [dtypes.float64.as_datatype_enum]
        placeholder_type_enum = [input_ph.dtype.as_datatype_enum]
        output_node_names = ['output']
        ## 对 graph 进行优化, 把和 inference 无关的节点给删除, 比如 Saver 有关的节点
        graph_def = optimize_for_inference_lib.optimize_for_inference(
            graph.as_graph_def(), input_node_names, output_node_names, placeholder_type_enum
        )
        check_graph_def(graph_def)
        ## 将 ckpt 转换为 frozen_graph, 网络权重和结构写入统一 pb 文件中, 参数以 Const 的形式保存
        frozen_graph = tf.graph_util.convert_variables_to_constants(sess, 
            graph_def, output_node_names)
        out_graph_path = os.path.join('.', "frozen_model.pb")
        with tf.gfile.GFile(out_graph_path, "wb") as f:
            f.write(frozen_graph.SerializeToString())

def read_frozen_graph():
    with tf.Graph().as_default() as graph:
        graph_def = tf.GraphDef()
        with open("frozen_model.pb", 'rb') as f:
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
        
        # print(graph_def)
    
    with tf.Session(graph=graph) as sess:
        output = sess.run(['output:0'], feed_dict={
    
    
            'input:0': test_data
        })
        print('frozen_graph:\n{}'.format(output))   

if __name__ == '__main__':
    write_frozen_graph()
    read_frozen_graph()

Entre write_frozen_graph()ellos llame optimize_for_inference_lib.optimize_for_inferencepara optimizar el nodo Graph, que se presentará en la siguiente sección.Además, llame tf.graph_util.convert_variables_to_constantspara convertir ckpt a frozen_graph, y los parámetros se guardan en forma de Const:

Optimización de gráficos de servicio

Cuando se generó el gráfico_congelado en la sección anterior, se llamó optimize_for_inference_lib.optimize_for_inferencepara optimizar el nodo Gráfico. Esta sección lo explica brevemente. Si imprime el gráfico cargado desde el punto de control antes de llamar a esta función, encontrará que la estructura contiene muchas operaciones que son no se requiere para el servicio en línea, como algoritmos de optimización Adam, guardado de modelos Saver, gradientes gradients, etc. , como se muestra en la figura a continuación:

optimize_for_inference_lib.optimize_for_inferenceUna de las tareas principales de la función es eliminar el Op inútil cuando el gráfico está sirviendo.

Esta función se define en https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference_lib.py ,

def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
                           placeholder_type_enum, toco_compatible=False):
  ## ..... 显示核心代码
  optimized_graph_def = strip_unused_lib.strip_unused(
      optimized_graph_def, input_node_names, output_node_names,
      placeholder_type_enum)
  optimized_graph_def = graph_util.remove_training_nodes(
      optimized_graph_def, output_node_names)
  ## .... 
  return optimized_graph_def

Donde strip_unused_lib.strip_unused se define de la siguiente manera:

def strip_unused(input_graph_def, input_node_names, output_node_names,
                 placeholder_type_enum):
  """Removes unused nodes from a GraphDef.
  Args:
    input_graph_def: A graph with nodes we want to prune.
    input_node_names: A list of the nodes we use as inputs.
    output_node_names: A list of the output nodes.
    placeholder_type_enum: The AttrValue enum for the placeholder data type, or
        a list that specifies one value per input node name.
  Returns:
    A `GraphDef` with all unnecessary ops removed.
  Raises:
    ValueError: If any element in `input_node_names` refers to a tensor instead
      of an operation.
    KeyError: If any element in `input_node_names` is not found in the graph.
  """
  for name in input_node_names:
    if ":" in name:
      raise ValueError(f"Name '{
      
      name}' appears to refer to a Tensor, not an "
                       "Operation.")

  # Here we replace the nodes we're going to override as inputs with
  # placeholders so that any unused nodes that are inputs to them are
  # automatically stripped out by extract_sub_graph().
  not_found = {
    
    name for name in input_node_names}
  inputs_replaced_graph_def = graph_pb2.GraphDef()
  for node in input_graph_def.node:
    if node.name in input_node_names:
      not_found.remove(node.name)
      placeholder_node = node_def_pb2.NodeDef()
      placeholder_node.op = "Placeholder"
      placeholder_node.name = node.name
      if isinstance(placeholder_type_enum, list):
        input_node_index = input_node_names.index(node.name)
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=placeholder_type_enum[
                input_node_index]))
      else:
        placeholder_node.attr["dtype"].CopyFrom(
            attr_value_pb2.AttrValue(type=placeholder_type_enum))
      if "_output_shapes" in node.attr:
        placeholder_node.attr["_output_shapes"].CopyFrom(node.attr[
            "_output_shapes"])
      if "shape" in node.attr:
        placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
      inputs_replaced_graph_def.node.extend([placeholder_node])
    else:
      inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])

  if not_found:
    raise KeyError(f"The following input nodes were not found: {
      
      not_found}.")

  output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
                                                  output_node_names)
  return output_graph_def

El código debe pasarse en graph_def, el nombre del nodo de entrada input_node_namesy el nombre del nodo de salida output_node_names. La gran pieza de código anterior es Placeholderreemplazar el nodo de entrada original con , que es reescribir todo el gráfico. Luego, en el archivo graph_util. extract_sub_graph , use el algoritmo BFS para conservar los nodos de servicio que se necesitan y elimine todos los nodos innecesarios:

def extract_sub_graph(graph_def, dest_nodes):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: An iterable of strings specifying the destination node names.
  Returns:
    The GraphDef of the sub-graph.
  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
  """

 ## ... BFS 遍历 Serving 时用到的节点

  nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

  nodes_to_keep_list = sorted(
      list(nodes_to_keep), key=lambda n: name_to_seq_num[n])
  # Now construct the output GraphDef
  out = graph_pb2.GraphDef()
  for n in nodes_to_keep_list:
    out.node.extend([copy.deepcopy(name_to_node[n])])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  
  return out

La función BFS se define de la siguiente manera:

def _node_name(n):
  if n.startswith("^"):
    return n[1:]
  else:
    return n.split(":")[0]

def _extract_graph_summary(graph_def):
  """Extracts useful information from the graph and returns them."""
  name_to_input_name = {
    
    }  # Keyed by the dest node name.
  name_to_node = {
    
    }  # Keyed by node name.

  # Keeps track of node sequences. It is important to still output the
  # operations in the original order.
  name_to_seq_num = {
    
    }  # Keyed by node name.
  seq = 0
  for node in graph_def.node:
    n = _node_name(node.name)
    name_to_node[n] = node
    name_to_input_name[n] = [_node_name(x) for x in node.input]
    ### ....
    name_to_seq_num[n] = seq
    seq += 1
  return name_to_input_name, name_to_node, name_to_seq_num

def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
  """Breadth first search for reachable nodes from target nodes."""
  nodes_to_keep = set()
  # Breadth first search to find all the nodes that we should keep.
  next_to_visit = list(target_nodes)
  while next_to_visit:
    node = next_to_visit[0]
    del next_to_visit[0]
    if node in nodes_to_keep:
      # Already visited this node.
      continue
    nodes_to_keep.add(node)
    if node in name_to_input_name:
      next_to_visit += name_to_input_name[node]
  return nodes_to_keep

La razón por la que estos fragmentos de código se extraen por separado es que puede graph_defdepurar . optimize_for_inference_lib.optimize_for_inferenceDespués del procesamiento de, el gráfico es más conciso y ligero. Imprima el Op en él para obtener:

Se puede ver que se usará en el entrenamiento Adam, Savery cuando se eliminan todos los nodos, todo el gráfico se vuelve extremadamente limpio y ordenado.

Resumir

Escribir un artículo es ir a por todas, luego bajar, tres veces, y luego ir a por todas.
Voy a jugar.

Supongo que te gusta

Origin blog.csdn.net/Eric_1993/article/details/126197197
Recomendado
Clasificación