tensorflow 模型保存与导入

Tensorflow 模型保存与导入


使用Tensorflow训练模型,将已经训练好的模型保存起来,留下接口供以后使用,或者由其他程序模型来调用,有三种方法,其中使用PB文件最为方便。

  • 先定义模型,需要有模型定义的源码,然后使用saver.restore(sess, checkpoint)。
  • 不需要定义模型,使用import_meta_graph。
  • 不需要定义模型,使用PB文件。

模型保存

模型保存过程

saver = tf.train.Saver(max_to_keep=256)

saver.save(sess, model_path, global_step=epoch_comp) # save model.
  • 首先需要定义saver类型,其中参数为max_to_keep 表示需要保存的模型个数。
  • 然后在训练最后使用saver.save()即可保存模型,需要注明model_path和global_step,global_step用来注明是哪一个 step 或者 epoch 保存的模型。
  • ​Tensorflow的计算图在运行时,以 MetaGrapDef 的形式实现计算图。在进行计算图保存时,将 MetaGraphDef 以二进制的形式写入磁盘,在保存模型产生的3个的文件中,MetaGraphDef 保存在.meta文件中;其中模型经过训练的模型参数,权重,可训练的变量保存在.data文件中;张量名到张量的对应映射关系保存在.index文件中。

PB文件保存

谷歌推荐的保存方式是保存为PB文件,它具有语言独立性,任何语言都可以解析它,允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。

  • PB文件的保存主要是导入graph结构和恢复权重的过程。其中import_meta_graph导入graph结构,restore是恢复权重,as_graph_def()返回此图的序列化图形化表示。具体代码如下。
  • PB文件保存有一个重要参数是output_node_names,即注明模型输出的节点。这里就能感受到对tensor手动命名的好处了。
# 定义路径
ckpt = "model/epoch-1"
meta_file_path = 'epoch1.meta'
pb_file_path ='epoch1.pb'
#转pb
saver = tf.train.import_meta_graph(meta_file_path)
with tf.Session() as sess:
    saver.restore(sess, ckpt )
    graph = tf.get_default_graph()
    input_graph_def = sess.graph.as_graph_def()
    # 下面的代码用来检查节点命名
    for node in input_graph_def.node:
        print("node.name = ", node.name)
    #
    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, 
    																graph.as_graph_def(),
                                                                    output_node_names=[output_node])
    tf.train.write_graph(output_graph_def, '', pb_file_path, as_text=False)

模型导入

猜你喜欢

转载自blog.csdn.net/s09094031/article/details/105394051