tensorflow之freeze_gragh

主要了解下freeze_graph的用法

以及了解下freeze_graph_test的一些相关知识(据说具有很好的学习价值)

freeze_graph.py源码链接:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

freeze_graph_test.py源码链接:

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph_test.py

tf模型的基本介绍

Tensorflow所有的文档格式都是基于Protocol Buffer,即protobuf

在文本文档中定义数据结构,protobuf工具生成C、Python和其他语言的类,这些类可以友好的加载、保存和方位数据

Tensorflow里的计算基础是Graph对象

它可以存储网络节点,每一个节点代表一个操作,并作为输入和输出相互链接在一起

GraphDef 类是ProtoBuf根据

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/graph.proto

以此为基础定义创建的对象。Protobuf工具会解析此文本文档,并生成用户加载、存储和操控图定义的代码。

将文档加载到 grapf_def 变量中,就可以访问其中的数据

可以使用下面的代码来遍历这些节点,基本上重要的部分都是存储在节点中了。

1
2
graph_def = graph_pb2.GraphDef()
for node in graph_def.node

每一个节点都是一个在node_def.proto定义的NodeDef对象,这些节点是Tensorflow图的基本构建块,每一个构建块都定义了一个操作以及其输入连接。NodeDef的成员如下所示:

  • name 节点的唯一标识符,该标识符不会被途中的任何其他节点使用
  • op 定义了要运行的操作,比如Add、MatMul、Conv2D
  • input 字符列表表,每个字符串都是另一个节点的名称,比如两个输入[“input_1:0”,”input_2:0”]
  • device 定义了在分布式环境中运行的位置
  • attr 包含某个节点的所有属性的键值对存储区

以上成员都可以通过 node.name node.op等来访问

因为tf在训练期间权重通常不会存储在文档格式内,而是保存在单独的检查点中,并且图中的Variable操作可在初始化操作时加载最新的值。

在部署到生产环境时,使用单独的文档往往不是很方便,因此我们需要一个脚本 freeze_graph.py

将这些检查点、文档冻结到一个文档中。

具体操作就是加载GraphDef,从最新的检查点文档中提取所有变量的值,然后将每个Variable操作替换为Const(其中包含存储在其属性中的权重的数值数据)。然后,它会剥离所有未用于前向推断的无关节点,并将生成的GraphDef保存到输出文档中。

freeze_graph.py

先了解下参数:

  • input_graph 模型文档,二进制pb或者文本meta,用input_binary来区分
  • input_saver 需要加载的Tensorflow saver文档
  • input_checkpoint 检查点文档,用于模型恢复变量值
  • checkpoint_version 变量文档的格式 (saver_pb2.SaverDef.V1 or saver_pb2.SaverDef.V2)
  • output_graph 冻结完成后的写入路径
  • input_binary 输入文档是否是二进制 True Or False
  • output_node_names 输出节点的名字,多个节点用逗号分隔
  • restore_op_name 已废弃
  • filename_tensor_name 已废弃
  • clear_devices 默认是True,是否清楚训练节点的设备
  • initializer_nodes 需要初始化的节点
  • variable_names_whitelist 指定需要恢复的变量
  • variable_names_blacklist 指定不用恢复的变量
  • input_meta_graph 需要加载的MetaGraphDef
  • input_saved_model_dir SavedModel文档和变量的路径
  • saved_model_tags 加载MetaGraphDef中的tag组,逗号分隔(MetaGraphDef中可以用tags来区分不同的计算图)

首先解析checkpoint版本:

1
2
3
4
5
6
7
if flags.checkpoint_version == 1:
checkpoint_version = saver_pb2.SaverDef.V1
elif flags.checkpoint_version == 2:
checkpoint_version = saver_pb2.SaverDef.V2
else:
raise ValueError("Invalid checkpoint version (must be '1' or '2'): %d" %
flags.checkpoint_version)

两种checkpoint的保存方法如下:

v1 v2
model.ckpt-0001 model.ckpt-0001.index
model.ckpt-0001.meta model.ckpt-0001.meta
model.ckpt-0001.data-00000-of-00001

然后解析输入的graphDef:

continue…

参考:

https://www.tensorflow.org/guide/extend/model_files#freezing

https://blog.csdn.net/czq7511/article/details/72452985

原文链接 大专栏  https://www.dazhuanlan.com/2019/08/24/5d612ad619bfd/

猜你喜欢

转载自www.cnblogs.com/chinatrump/p/11415213.html