tensorflow 读取两种格式的模型并进行预测

tensorflow 读取两种格式的模型并进行预测

1. 模型保存

1.1 checkpoint 模型

如图所示,
.meta – 保存图结构,即神经网络的网络结构
.data – 保存数据文件,即网络的权值,偏置,操作等等
.index – 是一个不可变得字符串表,每一个键都是张量的名称,它的值是一个序列化的BundleEntryProto。 每个BundleEntryProto描述张量的元数据:“数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
checkpoint – 文本文件,里面记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。在inference时,可以通过修改这个文件,指定使用哪个model.

保存模型:

saver = tf.train.Saver()
saver.save(sess, model_path)

  
  
  • 1
  • 2

其中model_path是模型保存路径。

1.2 frozen_graph模型

在工程中,我们往往需要将模型和权重固化,便于发布和预测。
使用tensorFlow官方提供的freeze_graph.py工具来保存相应模型。(代码中把freeze_graph.py文件放在commom.utils.tf路径下导入)

freeze_graph.py先加载模型文件,从checkpoint文件读取权重数据初始化到模型里的权重变量,再将权重变量转换成权重常量,然后再通过指定的输出节点将没用于输出推理的Op节点从图中剥离掉,再重新保存到指定的文件里(用write_graphdef或Saver)。

from tensorflow.core.protobuf import saver_pb2
from common.utils.tf import freeze_graph
# save model graph
tf.train.write_graph(
    sess.graph.as_graph_def(),
    os.path.join(model_path),
    GRAPH_PB_NAME,
    as_text=False)
# generate frozen graph
freeze_graph.freeze_graph(
    input_graph=os.path.join(model_path, GRAPH_PB_NAME),
    input_saver=False,
    input_binary=True,
    input_checkpoint=os.path.join(model_path, CHECKPOINT_PREFIX),
    output_node_names="viterbi_sequence,intent_prediction,intent_probs",
    restore_op_name=None,
    filename_tensor_name=None,
    output_graph=os.path.join(model_path, FROZEN_GRAPH_PB_NAME),
    clear_devices=False,
    initializer_nodes="",
    variable_names_whitelist="",
    variable_names_blacklist="",
    input_meta_graph=None,
    input_saved_model_dir=None,
    saved_model_tags=tf.saved_model.tag_constants.SERVING,
    checkpoint_version=saver_pb2.SaverDef.V2)

  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

其中model_path是模型保存路径,GRAPH_PB_NAME定义了图模型的名字。

freeze_graph主要参数(参考[4]博客中的参数说明):

  • input_graph : 模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分。
  • input_checkpoint : 检查点数据文件。
  • output_node_names : 输出节点的名字,有多个时用逗号分开。
  • output_graph : 保存整合后的输出模型。

2. 读取ckpt模型

在我的模型中,要求的输入有四个,分别是inputs_vocab,inputs_feature_list,sequence_length,max_length
计算得到的输出有两个viterbi_sequenceintent_prediction

ckpt = tf.train.get_checkpoint_state(arg.model_path + '/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')    
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = tf.get_default_graph()   
    # 加载模型中的操作节点	
    inputs_vocab = graph.get_operation_by_name('inputs_vocab').outputs[0]
    feature_data_list = graph.get_operation_by_name('inputs_feature_list').outputs[0]
    sequence_length = graph.get_operation_by_name('sequence_length').outputs[0]
    max_length = graph.get_operation_by_name('max_length').outputs[0]
    # 准备测试数据(略)
    # in_data = ...
    # fea_data_list = ...
    # length = ...
    # max_len = ...
    # feed 数据
    feed_dict = {inputs_vocab.name: in_data,
                 feature_data_list.name: fea_data_list,
                 sequence_length.name: length,
                 max_length.name: max_len}  
    # 计算
    viterbi_sequence = sess.run('viterbi_sequence:0', feed_dict)
    intent_prediction = sess.run('intent_prediction:0', feed_dict)      

  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

3. 读取frozen graph模型

# 读取图文件
with tf.gfile.FastGFile('./model/frozen_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # We load the graph_def in the default graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="",
            op_dict=None,
            producer_op_list=None
        )
        with tf.Session() as sess:
            # 根据名称返回tensor数据
            inputs_vocab = graph.get_tensor_by_name('inputs_vocab:0')
            feature_data_list = graph.get_tensor_by_name('inputs_feature_list:0')
            sequence_length = graph.get_tensor_by_name('sequence_length:0')
            max_length = graph.get_tensor_by_name('max_length:0')
            # 准备测试数据(略)
            # in_data = ...
            # fea_data_list = ...
            # length = ...
            # max_len = ...
            # feed 数据
            feed_dict = {inputs_vocab.name: in_data,
                         feature_data_list.name: fea_data_list,
                         sequence_length.name: length,
                         max_length.name: max_len}
            # 计算结果
            viterbi_sequence = graph.get_tensor_by_name('viterbi_sequence:0')
            intent_prediction = graph.get_tensor_by_name('intent_prediction:0')
            viterbi_sequence = sess.run(viterbi_sequence, feed_dict)
            intent_prediction = sess.run(intent_prediction, feed_dict)

  
  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35

注意,这里如果不使用上下文管理器Graph().as_default(),在进行预测的时候可能会报"The Session graph is empty. Add operations to the graph before calling run()…"的错误。

参考博客:

  1. tensorflow将训练好的模型freeze,即将权重固化到图里面,并使用该模型进行预测
  2. 使用TensorFlow C++ API构建线上预测服务
  3. Tensorflow加载预训练模型和保存模型
  4. tensorflow,使用freeze_graph.py将模型文件和权重数据整合在一起并去除无关的Op

猜你喜欢

转载自blog.csdn.net/l641208111/article/details/105291117