Tensorflow 模型持久化操作

模型保存方法

tf.train.Saver()

tf.train.write_graph


Tensorflow 官方代码

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

总共有11个参数,一个个介绍下(必选: 表示必须有值;可选: 表示可以为空):

1、input_graph:(必选)模型文件,可以是二进制的pb文件,或文本的meta文件,用input_binary来指定区分(见下面说明)

2、input_saver:(可选)Saver解析器。保存模型和权限时,Saver也可以自身序列化保存,以便在加载时应用合适的版本。主要用于版本不兼容时使用。可以为空,为空时用当前版本的Saver。

3、input_binary:(可选)配合input_graph用,为true时,input_graph为二进制,为false时,input_graph为文件。默认False

4、input_checkpoint:(必选)检查点数据文件。训练时,给Saver用于保存权重、偏置等变量值。这时用于模型恢复变量值。

5、output_node_names:(必选)输出节点的名字,有多个时用逗号分开。用于指定输出节点,将没有在输出线上的其它节点剔除。

6、restore_op_name:(可选)从模型恢复节点的名字。升级版中已弃用。默认:save/restore_all

7、filename_tensor_name:(可选)已弃用。默认:save/Const:0

8、output_graph:(必选)用来保存整合后的模型输出文件。

9、clear_devices:(可选),默认True。指定是否清除训练时节点指定的运算设备(如cpu、gpu、tpu。cpu是默认)

10、initializer_nodes:(可选)默认空。权限加载后,可通过此参数来指定需要初始化的节点,用逗号分隔多个节点名字。

11、variable_names_blacklist:(可先)默认空。变量黑名单,用于指定不用恢复值的变量,用逗号分隔多个变量名字。

用法:

例:python tensorflow/python/tools/free_graph.py \
--input_graph=some_graph_def.pb \ 注意:这里的pb文件是用tf.train.write_graph方法保存的
--input_checkpoint=model.ckpt.1001 \ 注意:这里若是r12以上的版本,只需给.data-00000....前面的文件名,如:model.ckpt.1001.data-00000-of-00001,只需写model.ckpt.1001  
--output_graph=/tmp/frozen_graph.pb 

--output_node_names=softmax

另外,如果模型文件是.meta格式的,也就是说用saver.Save方法和checkpoint一起生成的元模型文件,free_graph.py不适用,但可以改造下:

1、copy free_graph.py为free_graph_meta.py

2、修改free_graph.py,导入meta_graph:from tensorflow.python.framework import meta_graph

3、将91行到97行换成:input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def


这样改即可加载meta文件




猜你喜欢

转载自blog.csdn.net/u012968002/article/details/80352079