Introducción a los dos tipos de archivo del modelo de entrenamiento previo de Tensorflow ckpt y pb

Usamos  SSD (detección de fotogramas múltiples de disparo único) en el lado móvil para identificar objetos y comprender Graph en vehículos no tripulados de Tensorflow. Estamos familiarizados con el gráfico de cálculo de Graph y la introducción de conocimientos relevantes sobre funciones en Tensorflow 2.0 (el recomendado reemplazo de Graph en la versión 1.0). El uso de  esta función tf es para comprender los roles respectivos del flujo de control y el gráfico de cálculo. No importa qué método se use, en el aprendizaje profundo, lo más crítico es usar el entrenamiento previo modelo.

Al igual que otros marcos que necesitan cargar el modelo preentrenado, aquí también es necesario importar el archivo de parámetros de peso, es decir, el código que usa SSD para identificar objetos, de la siguiente manera:

MODEL_NAME = 'ssdlite_mobilenet_v2_coco_2018_05_09'
PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb' 
PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

Puede ver que aquí se cargan dos archivos: uno es un archivo pb y el otro es un archivo de etiqueta pbtxt.Para el archivo pb , podemos ver el nombre del archivo del directorio donde se encuentra ssdlite_mobilenet_v2_coco_2018_05_09 , y también podemos saber que pertenece al modelo de preentrenamiento SSD ligero en el lado móvil , y se utiliza el conjunto de datos de COCO2018 .
El siguiente paso es centrarse en la introducción de operaciones de archivo de modelo previas al entrenamiento.

1. Guarde el modelo ckpt

Veamos un ejemplo primero. Después del entrenamiento, guardamos el modelo. El propósito de esto es que no necesita volver a entrenar cada vez. Puede cargar directamente los parámetros para la inferencia, y también se puede trasplantar fácilmente a otros programas.

import tensorflow.compat.v1 as tf 
tf.disable_eager_execution()

v1 = tf.Variable(tf.constant([[11],[22]]),name='v1')
v2 = tf.Variable(tf.constant([[33],[44]]),name='v2')
result = v1 * v2

saver = tf.train.Saver()
with tf.Session() as sess:
    # 初始化所有变量
    #tf.global_variables_initializer().run()
    sess.run(tf.global_variables_initializer())
    print(sess.run(v1))
    print(sess.run(v2))
    print(sess.run(result))
    #这个扩展名".ckpt"可以忽略
    saver.save(sess,'model/model.ckpt')
'''
[[11]
 [22]]
[[33]
 [44]]
[[363]
 [968]]
'''

Como se muestra abajo:

De esta manera, el modelo pre-entrenado se guarda a través del método save.Un archivo de modelo .ckpt estándar contiene los siguientes archivos:

checkpoint: archivo de texto, por lo que se puede abrir directamente, el contenido es el siguiente:

model_checkpoint_path: "modelo.ckpt" 和all_model_checkpoint_paths: "modelo.ckpt"

model.ckpt.data-00000-of-00001: guarde el valor de la variable
model.ckpt.index: guarde el nombre de la variable, que se puede considerar como la forma de clave-valor junto con el
modelo anterior.ckpt. meta: guarda la estructura del gráfico de cálculo

Veamos las variables en el modelo:

read_ckpt =tf.train.NewCheckpointReader("model/model.ckpt")
print(read_ckpt.debug_string().decode("utf-8"))
print(read_ckpt.get_variable_to_dtype_map())
print(read_ckpt.get_variable_to_shape_map())
'''
v1 (DT_INT32) [2,1]
v2 (DT_INT32) [2,1]
{'v1': tf.int32, 'v2': tf.int32}
{'v1': [2, 1], 'v2': [2, 1]}
'''

Use tf.train.NewCheckpointReader para leer el archivo del modelo, mostrando el nombre de la variable, el tipo de datos y la forma. Los dos últimos métodos obtienen el tipo de variable y el tipo de diccionario de la forma, respectivamente.

2. Cargue el modelo ckpt

 Después de guardar el modelo anterior, carguemos el modelo para probarlo:

import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

saver = tf.train.import_meta_graph('model/model.ckpt.meta')
with tf.Session() as sess:
    saver.restore(sess,'model/model.ckpt')
    out = tf.get_default_graph().get_tensor_by_name('mul:0')
    print(sess.run(out))
'''
INFO:tensorflow:Restoring parameters from model/model.ckpt
[[363]
 [968]]
'''

Se puede ver que la estructura del gráfico de cálculo se carga primero a través de tf.train.import_meta_graph . Luego use el método de restauración para restaurar el archivo del modelo.
Lo mismo también es necesario para deshabilitar temporalmente este modo de ejecución instantánea: método tf.disable_eager_execution()
para obtener el nombre del nodo, porque esta es una operación de multiplicación, por lo que es mul , si es suma, es suma , por supuesto, si son otras operaciones, si desea saber el nombre Puede imprimir y ver directamente, por ejemplo, antes de guardar el modelo, puede ver:

print(result)
Tensor("mul:0", shape=(2, 1), dtype=int32)

3. Guarda el modelo pb

Como puede ver en la imagen de arriba, el modelo de preentrenamiento guardado tiene 4 archivos, cada uno de los cuales está separado, lo cual es relativamente desordenado. Podemos escribir variables y valores de peso, etc., en un solo archivo, que es el único. en el artículo anterior El archivo pb . Echemos un vistazo a cómo guardarlo como un archivo pb

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import graph_util
tf.disable_eager_execution()

v1 = tf.Variable(tf.constant([[1],[2]]),name='v1')
v2 = tf.Variable(tf.constant([[3],[4]]),name='v2')
result = v1 * v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    #print(graph_def)
    output_graph_def = graph_util.convert_variables_to_constants(sess,graph_def,['mul'])
    with tf.gfile.GFile('model/newmodel.pb','wb') as f:
        f.write(output_graph_def.SerializeToString())

Aquí podemos ver que todavía hay una diferencia con guardar el archivo ckpt Primero, obtenga el gráfico de cálculo, luego use graph_util.convert_variables_to_constants para convertir la variable en una constante (para que la arquitectura de la red y los pesos se guarden en un archivo), y finalmente a través de SerializeToString( ) para convertirlo en un flujo de bytes y escribirlo en el archivo.

4. Cargue el modelo pb

Después de guardar el archivo del modelo pb anterior, cargamos el modelo para probar:

with tf.Session() as sess:
    with tf.io.gfile.GFile('model/newmodel.pb','rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    result = tf.import_graph_def(graph_def,return_elements=['mul:0'])
    print(sess.run(result))
'''
[array([[3],
       [8]])]
'''

Lo mismo es usar el método tf.io.gfile.GFile , donde se especifica el parámetro rb para leer el archivo del modelo, luego leerlo y analizarlo en un gráfico de cálculo, y finalmente usar tf.import_graph_def para importarlo en el gráfico predeterminado , donde la sesión Si el gráfico de cálculo no se especifica en los parámetros, es el gráfico predeterminado.
Además, no importa si es guardar o leer, se lleva a cabo en la sesión tf.Session.Si necesita calcular el valor del nodo, puede usar la función sess.run().

5. marcador de posición tf.placeholder

Imprimimos el gráfico de cálculo en el artículo anterior y podemos ver que, de hecho, hay muchas operaciones en él , y cada minilote será una operación , que también es una sobrecarga de recursos, por lo que usamos el marcador de posición tf.placeholder para manejar la operación repetida , por ejemplo, cada vez que el minilote se pasa a x = tf.placeholder(tf.float32,[None,32]) , la x pasada la próxima vez reemplazará directamente a la x anterior y no generará una nueva operación , que ahorra gastos generales.

import tensorflow.compat.v1 as tf
import numpy as np
tf.disable_eager_execution()

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
output = tf.multiply(a, b)
 
with tf.Session() as sess:
    print(sess.run(output, feed_dict = {a:[12.], b: [3.5]}))#[42.]

 Use marcadores de posición y luego alimente los datos usando feed_dict .

6. Manejo de errores

RuntimeError: cuando la ejecución entusiasta está habilitada, `var_list` debe especificar una lista o dictado de variables para guardar

Error de tiempo de ejecución: 'var_list' debe especificar una lista o diccionario de variables para guardar cuando la ejecución inmediata está habilitada

Podemos resolverlo deshabilitando la ejecución instantánea. Esta ejecución instantánea significa que Tensorflow ejecutará cada operación inmediatamente en lugar de crear primero un gráfico de cálculo. Esta ejecución entusiasta está habilitada de forma predeterminada desde Tensorflow 2.0. La presentaré aquí por conveniencia Puntos de conocimiento, utilizará 1.0 y se deshabilitará temporalmente, de modo que el gráfico de cálculo se construya en tiempo de ejecución. tf.disable_eager_execution()

RuntimeError: el gráfico de sesión está vacío. Agregue operaciones al gráfico antes de llamar a run().
RuntimeError: el gráfico de sesión está vacío. Agrega operaciones al gráfico antes de llamar a run().

Aquí también está relacionado con la ejecución instantánea, la misma que deshabilitamos temporalmente: tf.disable_eager_execution()

Además, de tensorflow.python.framework.graph_util_impl se quitará y eliminará después de la versión 2.0. El código anterior también es compatible con la versión 1.0 con fines de demostración.

RuntimeError: tf.placeholder() no es compatible con la ejecución entusiasta
Error de tiempo de ejecución: el marcador de posición no es compatible con la ejecución entusiasta

Lo mismo deshabilita temporalmente la ejecución ansiosa: tf.disable_eager_execution()

Supongo que te gusta

Origin blog.csdn.net/weixin_41896770/article/details/131990678
Recomendado
Clasificación