Guardado del modelo de Tensorflow, modificación de nodos y optimización de gráficos de servicio
Directorio de artículos
- Guardado del modelo de Tensorflow, modificación de nodos y optimización de gráficos de servicio
Prólogo (no relacionado con el texto, puede ignorarse)
En un futuro cercano, tengo la intención de resumir algunos conocimientos básicos de Tensorflow para facilitar la referencia. La motivación para escribir este artículo es considerar un pequeño problema: a menudo usamos tf.data
una serie de Los gráficos suelen ser nodos iteradores (como la llamada tf.data.make_one_shot_iterator
y el get_next()
método ), pero cuando estaba en servicio, estaba pensando en cómo lidiar con los nodos de entrada y cómo tf.placeholder
agregar nuevos al gráfico de servicio.
Un método es reescribir el gráfico de servicio, actualizar el nodo de entrada tf.placeholder
y luego ingresarlo en el modelo para generar un nuevo gráfico; pero espero que haya un método más conciso, como si es posible reemplazar directamente el nodo de entrada del iterador. tf.placeholder
with Incluso si no sé cómo está escrito el código del modelo, puedo construir un gráfico de servicio Bajo la guía de esta pregunta, tengo un poco de comprensión profunda de conceptos como guardar y cargar modelos TF y Graph /Metagráfico.
descripción general
Este artículo presenta el método para guardar parte del modelo de Tensorflow, incluido principalmente checkpoint
el formato , frozen_graph
el formato ( SavedModel
el formato se omite temporalmente), a través del ejemplo de código para comprender el método para guardar el modelo, la optimización del gráfico de servicio y la modificación. y actualización de los nodos en el gráfico de Servicio.
código de dirección
El código de este artículo se probó con éxito en el Python 3.5.2
entorno Tensorflow 1.15.0
.
Todo el código de este artículo se puede descargar desde https://github.com/axzml/BlogShare/tree/master/Tensorflow/GraphDef .
Amplia publicidad
Puede buscar "Jenny's Algorithm Road" o "world4458" en WeChat, seguir mi cuenta pública de WeChat y obtener las últimas actualizaciones de artículos técnicos originales a tiempo.
Además, puede echar un vistazo a la columna de Zhihu PoorMemory-Machine Learning , y también se publicarán artículos futuros en la columna de Zhihu.
formato de punto de control
código de entrenamiento y guardar ckpt
Escribí un código de entrenamiento simple ( train.py
) de la siguiente manera, completo con cinco órganos internos, que define tres funciones principales:
data_generator()
: Genere datos falsos para participar en el entrenamiento del modelo.model()
: define una red neuronal simpletrain()
: Defina el código de entrenamiento, llame para guardar el modelotf.train.Saver()
en forma de punto de control
# _*_ coding:utf-8 _*_
## train.py
import tensorflow as tf
import os
import numpy as np
from os.path import join, exists
batch_size = 2
steps = 10
epochs = 1
emb_dim = 4
sample_num = epochs * steps * batch_size
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
def data_generator():
"""产生 Fake 训练数据"""
dataset = tf.data.Dataset.from_tensor_slices((np.random.randn(sample_num, emb_dim),\
np.random.randn(sample_num)))
dataset = dataset.repeat(epochs).batch(batch_size)
iterator = tf.data.make_one_shot_iterator(dataset)
feature, label = iterator.get_next()
return feature, label
def model(feature, params=[10, 5, 1]):
"""定义模型, 3层DNN"""
fc1 = tf.layers.dense(feature, params[0], activation=tf.nn.relu, name='fc1')
fc2 = tf.layers.dense(fc1, params[1], activation=tf.nn.relu, name='fc2')
fc3 = tf.layers.dense(fc2, params[2], activation=tf.nn.sigmoid, name='fc3')
out = tf.identity(fc3, name='output')
return out
def train():
feature, label = data_generator()
output = model(feature)
loss = tf.reduce_mean(tf.square(output - label))
train_op = tf.train.AdamOptimizer(learning_rate=0.1, name='Adam').minimize(loss)
saver = tf.train.Saver()
if exists(checkpoint_dir):
os.system('rm -rf {}'.format(checkpoint_dir))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
try:
local_step = 0
save_freq = 2
while True:
local_step += 1
_, loss_val = sess.run([train_op, loss])
if local_step % save_freq == 0:
saver.save(sess, saver_dir)
print('loss: {:.4f}'.format(loss_val))
except tf.errors.OutOfRangeError:
print("train end!")
if __name__ == '__main__':
train()
La python train.py
ejecución generará checkpoint_dir
un directorio bajo el directorio actual, que consta de lo siguiente:
checkpoint_dir/
|-- 0.data-00000-of-00001 ## 记录了网络参数值
|-- 0.index ## 记录了网络参数名
|-- 0.meta ## 保存 MetaGraphDef, 该文件以 pb 格式记录了网络结构
`-- checkpoint ## 该文件记录了最新的 ckpt
Cargue ckpt y verifique la estructura del gráfico
checkpoint
El modelo en el formato debe cargarse en el marco de Tensorflow. Por ejemplo, escribe eval.py
en inferencia, el código es el siguiente:
#_*_ coding:utf-8 _*_
## eval.py
import tensorflow as tf
import os
from os.path import join, exists
import numpy as np
emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
np.random.seed(123)
test_data = np.random.randn(4, emb_dim) ## 生成测试数据
def eval_graph():
with tf.Session() as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, model_file)
output = sess.run(['output:0'], feed_dict={
'IteratorGetNext:0': test_data
})
print('eval_graph:\n{}'.format(output))
if __name__ == '__main__':
eval_graph()
En el código anterior, observe que los nombres de los nodos de entrada y salida son output
y IteratorGetNext
Para el nodo de salida, train.py
ya que model()
se usa en la función de
out = tf.identity(fc3, name='output')
Cambie el nombre del nodo de salida a output
, de modo que el nombre del nodo de salida sea muy fácil de determinar. Pero el nombre del nodo de entrada no es muy fácil de determinar, porque la tf.data
API se usa para pasar los datos durante el entrenamiento y el nodo de entrada no tiene un nombre explícito Sin embargo, debido a que Al guardar el modelo, la estructura de la red se almacenó en 0.meta
el archivo , por lo que los nodos de entrada de la red se pueden ver analizando el archivo, el método específico es el siguiente:
#_*_ coding:utf-8 _*_
## check_graph.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
from tensorflow.core.protobuf.meta_graph_pb2 import MetaGraphDef
from google.protobuf import text_format
import os
from os.path import join, exists
import numpy as np
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
def read_pb_meta(meta_file):
"""读取 pb 格式的 meta 文件"""
meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
return meta_graph_def
def read_txt_meta(txt_meta_file):
"""读取文本格式的 meta 文件"""
meta_graph = MetaGraphDef()
with open(txt_meta_file, 'rb') as f:
text_format.Merge(f.read(), meta_graph)
return meta_graph
def read_pb_graph(graph_file):
"""读取 pb 格式的 graph_def 文件"""
try:
with tf.gfile.GFile(graph_file, 'rb') as pb:
graph_def = tf.GraphDef()
graph_def.ParseFromString(pb.read())
except IOError as e:
raise Exception("Parse '{}' Failed!".format(graph_file))
return graph_def
def check_graph_def(graph_def):
"""检查 graph_def 中的各节点"""
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
name=""
)
print('===> {}'.format(type(graph)))
for op in graph.get_operations():
print(op.name, op.values()) ## 打印网络结构
def check_graph(graph_file):
"""检查 pb 格式的 graph_def 文件中的各节点"""
graph_def = read_pb_graph(graph_file)
check_graph_def(graph_def)
if __name__ == '__main__':
check_graph_def(read_pb_meta(meta_file).graph_def)
El resultado de salida se muestra en la siguiente figura. Se puede encontrar que el nodo fc1/kernel
más cercano es IteratorGetNext
, por lo que básicamente se puede confirmar que el nombre del nodo de entrada es ese.
modificación de nodo
Ahora volviendo a la pregunta mencionada en el "Prefacio", si quiero usar el tf.placeholder
nodo como el nodo de entrada del Gráfico en lugar de usarlo, IteratorGetNext
¿cómo debo lograrlo? Por un lado, puedo reescribir el Tensorflow Grafique y utilícelo tf.placeholder
como entrada; por otro lado, de hecho, puede considerar reemplazar IteratorGetNet
el nodo con un nodo personalizado. Este paso se refiere a la publicación del blog Cómo modificar el gráfico TF después de construirlo . El método específico es el siguiente , el código está infer.py
en :
#_*_ coding:utf-8 _*_
## infer.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
import os
from os.path import join, exists
import numpy as np
emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
np.random.seed(123)
test_data = np.random.randn(4, emb_dim)
def read_pb_meta(meta_file):
meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
return meta_graph_def
def update_node(graph, src_node_name, tar_node):
"""
@params:
graph : tensorflow Graph object
src_node_name : source node name to be modified
tar_node : target node
"""
input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
for op in input.consumers():
idx_list = []
for idx, item in enumerate(op.inputs):
if src_node_name in item.name:
idx_list.append(idx)
for idx in idx_list:
op._update_input(idx, tar_node)
def modify_graph():
meta_graph_def = read_pb_meta(meta_file)
with tf.Graph().as_default() as graph:
tf.import_graph_def(meta_graph_def.graph_def, name="")
input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
update_node(graph, 'IteratorGetNext', input_ph)
with tf.Session(graph=graph) as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, model_file)
output = sess.run(['output:0'], feed_dict={
'input:0': test_data
})
print('modify_graph:\n{}'.format(output))
if __name__ == '__main__':
modify_graph()
Este archivo define la función update_node
para reemplazar los nodos en el gráfico, la función es la siguiente:
def update_node(graph, src_node_name, tar_node):
"""
@params:
graph : tensorflow Graph object
src_node_name : source node name to be modified
tar_node : target node
"""
input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
for op in input.consumers():
idx_list = []
for idx, item in enumerate(op.inputs):
if src_node_name in item.name:
idx_list.append(idx)
for idx in idx_list:
op._update_input(idx, tar_node)
Entre ellos, src_node_name
representa el nombre del nodo que se va a reemplazar, por ejemplo, si desea reemplazarlo IteratorGetNext
. Busque el nodo correspondiente graph
a input
, y luego llame para input.consumers()
encontrar el nodo que usa el nodo, y luego reemplace el nodo op
a través del entrada op
actualizada ( ).Debido al método de reemplazo, op.inputs
se requiereop._update_input
un índice idx
, por lo que se usa idx_list
para registrar el índice del nodo que se reemplazará.
formato de gráfico congelado
checkpoint
El formato descrito anteriormente frozen_graph
guarda la estructura y los parámetros de la red por separado, mientras que el formato escribirá los parámetros de la red en GraphDef en forma de nodos Const y los guardará en un archivo protobuf unificado, porque protobuf es una serialización de datos multiplataforma y multilenguaje. protocol , por lo que también puede usar C++/Java/Python para cargar el modelo.
Un ejemplo simple de convertir ckpt a frozen_graph se escribe a continuación frozen_graph.py
, el código es el siguiente:
#_*_ coding:utf-8 _*_
## frozen_graph.py
import tensorflow as tf
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import dtypes
from tensorflow.python.tools import optimize_for_inference_lib
import os
from os.path import join, exists
import numpy as np
emb_dim = 4
checkpoint_dir = 'checkpoint_dir'
meta_name = '0'
saver_dir = join(checkpoint_dir, meta_name)
meta_file = saver_dir + '.meta'
model_file = tf.train.latest_checkpoint(checkpoint_dir)
np.random.seed(123)
test_data = np.random.randn(4, emb_dim)
def read_pb_meta(meta_file):
meta_graph_def = meta_graph.read_meta_graph_file(meta_file)
return meta_graph_def
def update_node(graph, src_node_name, tar_node):
"""
@params:
graph : tensorflow Graph object
src_node_name : source node name to be modified
tar_node : target node
"""
input = graph.get_tensor_by_name('{}:0'.format(src_node_name))
for op in input.consumers():
idx_list = []
for idx, item in enumerate(op.inputs):
if src_node_name in item.name:
idx_list.append(idx)
for idx in idx_list:
op._update_input(idx, tar_node)
def check_graph_def(graph_def):
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
name=""
)
print('===> {}'.format(type(graph)))
for op in graph.get_operations():
print(op.name, op.values()) ## 打印网络结构
def write_frozen_graph():
meta_graph_def = read_pb_meta(meta_file)
with tf.Graph().as_default() as graph:
tf.import_graph_def(meta_graph_def.graph_def, name="")
input_ph = tf.placeholder(tf.float64, [None, emb_dim], name='input')
update_node(graph, 'IteratorGetNext', input_ph)
with tf.Session(graph=graph) as sess:
saver = tf.train.import_meta_graph(meta_file)
saver.restore(sess, model_file)
input_node_names = ['input']
##placeholder_type_enum = [dtypes.float64.as_datatype_enum]
placeholder_type_enum = [input_ph.dtype.as_datatype_enum]
output_node_names = ['output']
## 对 graph 进行优化, 把和 inference 无关的节点给删除, 比如 Saver 有关的节点
graph_def = optimize_for_inference_lib.optimize_for_inference(
graph.as_graph_def(), input_node_names, output_node_names, placeholder_type_enum
)
check_graph_def(graph_def)
## 将 ckpt 转换为 frozen_graph, 网络权重和结构写入统一 pb 文件中, 参数以 Const 的形式保存
frozen_graph = tf.graph_util.convert_variables_to_constants(sess,
graph_def, output_node_names)
out_graph_path = os.path.join('.', "frozen_model.pb")
with tf.gfile.GFile(out_graph_path, "wb") as f:
f.write(frozen_graph.SerializeToString())
def read_frozen_graph():
with tf.Graph().as_default() as graph:
graph_def = tf.GraphDef()
with open("frozen_model.pb", 'rb') as f:
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# print(graph_def)
with tf.Session(graph=graph) as sess:
output = sess.run(['output:0'], feed_dict={
'input:0': test_data
})
print('frozen_graph:\n{}'.format(output))
if __name__ == '__main__':
write_frozen_graph()
read_frozen_graph()
Entre write_frozen_graph()
ellos llame optimize_for_inference_lib.optimize_for_inference
para optimizar el nodo Graph, que se presentará en la siguiente sección.Además, llame tf.graph_util.convert_variables_to_constants
para convertir ckpt a frozen_graph, y los parámetros se guardan en forma de Const:
Optimización de gráficos de servicio
Cuando se generó el gráfico_congelado en la sección anterior, se llamó optimize_for_inference_lib.optimize_for_inference
para optimizar el nodo Gráfico. Esta sección lo explica brevemente. Si imprime el gráfico cargado desde el punto de control antes de llamar a esta función, encontrará que la estructura contiene muchas operaciones que son no se requiere para el servicio en línea, como algoritmos de optimización Adam
, guardado de modelos Saver
, gradientes gradients
, etc. , como se muestra en la figura a continuación:
optimize_for_inference_lib.optimize_for_inference
Una de las tareas principales de la función es eliminar el Op inútil cuando el gráfico está sirviendo.
Esta función se define en https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/optimize_for_inference_lib.py ,
def optimize_for_inference(input_graph_def, input_node_names, output_node_names,
placeholder_type_enum, toco_compatible=False):
## ..... 显示核心代码
optimized_graph_def = strip_unused_lib.strip_unused(
optimized_graph_def, input_node_names, output_node_names,
placeholder_type_enum)
optimized_graph_def = graph_util.remove_training_nodes(
optimized_graph_def, output_node_names)
## ....
return optimized_graph_def
Donde strip_unused_lib.strip_unused se define de la siguiente manera:
def strip_unused(input_graph_def, input_node_names, output_node_names,
placeholder_type_enum):
"""Removes unused nodes from a GraphDef.
Args:
input_graph_def: A graph with nodes we want to prune.
input_node_names: A list of the nodes we use as inputs.
output_node_names: A list of the output nodes.
placeholder_type_enum: The AttrValue enum for the placeholder data type, or
a list that specifies one value per input node name.
Returns:
A `GraphDef` with all unnecessary ops removed.
Raises:
ValueError: If any element in `input_node_names` refers to a tensor instead
of an operation.
KeyError: If any element in `input_node_names` is not found in the graph.
"""
for name in input_node_names:
if ":" in name:
raise ValueError(f"Name '{
name}' appears to refer to a Tensor, not an "
"Operation.")
# Here we replace the nodes we're going to override as inputs with
# placeholders so that any unused nodes that are inputs to them are
# automatically stripped out by extract_sub_graph().
not_found = {
name for name in input_node_names}
inputs_replaced_graph_def = graph_pb2.GraphDef()
for node in input_graph_def.node:
if node.name in input_node_names:
not_found.remove(node.name)
placeholder_node = node_def_pb2.NodeDef()
placeholder_node.op = "Placeholder"
placeholder_node.name = node.name
if isinstance(placeholder_type_enum, list):
input_node_index = input_node_names.index(node.name)
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum[
input_node_index]))
else:
placeholder_node.attr["dtype"].CopyFrom(
attr_value_pb2.AttrValue(type=placeholder_type_enum))
if "_output_shapes" in node.attr:
placeholder_node.attr["_output_shapes"].CopyFrom(node.attr[
"_output_shapes"])
if "shape" in node.attr:
placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
inputs_replaced_graph_def.node.extend([placeholder_node])
else:
inputs_replaced_graph_def.node.extend([copy.deepcopy(node)])
if not_found:
raise KeyError(f"The following input nodes were not found: {
not_found}.")
output_graph_def = graph_util.extract_sub_graph(inputs_replaced_graph_def,
output_node_names)
return output_graph_def
El código debe pasarse en graph_def
, el nombre del nodo de entrada input_node_names
y el nombre del nodo de salida output_node_names
. La gran pieza de código anterior es Placeholder
reemplazar el nodo de entrada original con , que es reescribir todo el gráfico. Luego, en el archivo graph_util. extract_sub_graph , use el algoritmo BFS para conservar los nodos de servicio que se necesitan y elimine todos los nodos innecesarios:
def extract_sub_graph(graph_def, dest_nodes):
"""Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
Args:
graph_def: A graph_pb2.GraphDef proto.
dest_nodes: An iterable of strings specifying the destination node names.
Returns:
The GraphDef of the sub-graph.
Raises:
TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
"""
## ... BFS 遍历 Serving 时用到的节点
nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)
nodes_to_keep_list = sorted(
list(nodes_to_keep), key=lambda n: name_to_seq_num[n])
# Now construct the output GraphDef
out = graph_pb2.GraphDef()
for n in nodes_to_keep_list:
out.node.extend([copy.deepcopy(name_to_node[n])])
out.library.CopyFrom(graph_def.library)
out.versions.CopyFrom(graph_def.versions)
return out
La función BFS se define de la siguiente manera:
def _node_name(n):
if n.startswith("^"):
return n[1:]
else:
return n.split(":")[0]
def _extract_graph_summary(graph_def):
"""Extracts useful information from the graph and returns them."""
name_to_input_name = {
} # Keyed by the dest node name.
name_to_node = {
} # Keyed by node name.
# Keeps track of node sequences. It is important to still output the
# operations in the original order.
name_to_seq_num = {
} # Keyed by node name.
seq = 0
for node in graph_def.node:
n = _node_name(node.name)
name_to_node[n] = node
name_to_input_name[n] = [_node_name(x) for x in node.input]
### ....
name_to_seq_num[n] = seq
seq += 1
return name_to_input_name, name_to_node, name_to_seq_num
def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
"""Breadth first search for reachable nodes from target nodes."""
nodes_to_keep = set()
# Breadth first search to find all the nodes that we should keep.
next_to_visit = list(target_nodes)
while next_to_visit:
node = next_to_visit[0]
del next_to_visit[0]
if node in nodes_to_keep:
# Already visited this node.
continue
nodes_to_keep.add(node)
if node in name_to_input_name:
next_to_visit += name_to_input_name[node]
return nodes_to_keep
La razón por la que estos fragmentos de código se extraen por separado es que puede graph_def
depurar . optimize_for_inference_lib.optimize_for_inference
Después del procesamiento de, el gráfico es más conciso y ligero. Imprima el Op en él para obtener:
Se puede ver que se usará en el entrenamiento Adam
, Saver
y cuando se eliminan todos los nodos, todo el gráfico se vuelve extremadamente limpio y ordenado.
Resumir
Escribir un artículo es ir a por todas, luego bajar, tres veces, y luego ir a por todas.
Voy a jugar.