ckpt转pb

有时需要将tensorflow训练得到的ckpt固化成pb使用。Tensorflow提供了相关的固化命令脚本,下面是我搜集到的转换步骤。参考https://www.yanxishe.com/columnDetail/15278

一、tensorflow保存下的内容
用 tf.train.Saver.save() 方式保存下来的checkpoint会产生四个文件:
在这里插入图片描述

checkpoint
记录了部分已存储和最近存储的模型:
model_checkpoint_path: “mtcnn-3000000”
all_model_checkpoint_paths: “mtcnn-3000000”

mtcnn-3000000.data-00000-of-00001
保存了模型的所有变量的值,TensorBundle集合。

model.ckpt.index
string-string的映射表,映射表的key值为tensor名,value为serialized BundleEntryProto,每个BundleEntryProto表述了tensor的metadata。

model.ckpt.meta
保存了graph结构,包括 GraphDef,SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值。

二、固化命令
进入到当前环境安装目录,我的环境名称叫tensorflow-gpu,找到虚拟环境下tensorflow-gpu/python/tools/,执行命令
Python freeze_graph.py
–input_meta_graph=model.ckpt.meta
–input_checkpoint=model.ckpt
–output_graph=frozen_graph_meta.pb
–output_node_name=embeddings
–input_binary=True
问题一、可能遇到的错误
UnicodeDecoderError: ‘utf-8’ codec can’t decode byte 0xd8 in position 1: invalid continuation byte
解决:传入参数 --input_binary=True

问题二、当我们手上只有别人训练好的模型文件时,如何确定输入参数中的–output_node_name呢?
最直观的方式是使用TensorBoard查看图结构,步骤如下:
(1)从.meta文件生成Tensorboard所需的log,可以通过以下代码生成,运行脚本后在log目录下生成events.out.tfevents.xxx文件。

import tensorflow as tf
import os
def write_graph_log(meta_file, log_dir):
if not os.path.exists(log_dir):
os.mkdir(log_dir)
g = tf.Graph()
with g.as_default() as g:
tf.train.import_meta_graph(meta_file)
with tf.Session(graph=g) as sess:
tf.summary.FileWriter(logdir=log_dir, graph=g)
if name == ‘main’:
write_graph_log(‘model.ckpt.meta’, ‘./log/’)
(2) 在Windows下通过cmd命令行启动Tensorboard

cd model_dir # 进入模型文件所在的目录
tensorboard --logdir=log # 启动tensorboard,指定log目录
(3)浏览器打开tensorflow显示的网址(一般为http://127.0.0.1:6006),通过可视化的图结构可以清楚地看到输入和输出节点的名字。

猜你喜欢

转载自blog.csdn.net/chenaxin/article/details/120215776