tensorflow深度学习实战笔记(二):把训练好的模型进行固化

目录

一、导出前向传播图

二、对模型进行固化

三、pb文件转tflite文件


说明:接我的上一篇博客:tensorflow深度学习实战笔记(一):使用tensorflow slim自带的模型训练自己的数据
 

前面讲解了如何用tensorflow slim训练自己的模型,现在讲解如何把训练好的模型的cpkt文件固化成pb文件(即最终的模型)

一、导出前向传播图

在slim文件夹下有export_inference_graph.py文件(slim文件夹的位置可以参考我的上一篇博客),运行该脚本即可导出前向传播图,运行方式如下:

python export_inference_graph.py \
  --alsologtostderr \
  --dataset_dir=/home/yuping-chen/slim/my_data/fruits/ \#数据集的路径
  --dataset_name=fruit \#数据集的名字
  --model_name=inception_v3 \#导出的模型
  --image_size=224 \#图片尺寸
  --output_file=my_model/inception_v3_inf.pb#输出文件名,可以自定义

运行结束后即可在你指定的位置生成相应的pb文件,这个文件是前向传播图,并没有参数,所以也不是最终的模型,因此文件较小。

二、对模型进行固化

对上面生成的前向传播图进行固化,即把cpkt文件的参数导入到前向传播图中得到最终的模型,固化方式有两种,一是使用tensorflow自带的脚本,二是使用bazel工具。

两种方式差不多,只是使用的方式不一样,只有第一行不一样,其他的参数都是一样的。

2.1使用tensorflow自带的脚本进行固化

python -u /usr/local/lib/python2.7/dist-packages/tensorflow/python/tools/freeze_graph.py \
  --input_graph=/home/yuping-chen/slim/my_model/inception_v3_inf.pb \#上一步的前向传播图
  --input_checkpoint=/home/yuping-chen/slim/my_save_model/fruit-models/inception_v3_140/model.ckpt-68074 \#自己训练的cpkt文件
  --output_graph=/home/yuping-chen/slim/my_model/frozen_model/frozen_inception_v3.pb \#最终的模型
  --input_binary=True \
  --output_node_name=InceptionV3/Predictions/Reshape_1 

2.2使用bazel工具进行固化

bazel的具体安装和使用请参考我的另一篇博客:tensorflow深度学习实战笔记(四):bazel编译tensorflow工具的使用方法

bazel-bin/tensorflow/python/tools/freeze_graph \
  --input_graph=/home/yuping-chen/slim/my_model/inception_v3_inf.pb \#上一步的前向传播图
  --input_checkpoint=/home/yuping-chen/slim/my_save_model/fruit-models/inception_v3_140/model.ckpt-68074 \#自己训练的cpkt文件
  --output_graph=/home/yuping-chen/slim/my_model/frozen_model/frozen_inception_v3.pb \#最终的模型
  --input_binary=True \
  --output_node_name=InceptionV3/Predictions/Reshape_1 

运行结束后即可在你指定的位置生成相应的pb文件,这个文件是最终的模型,和前向传播图不同,治理包含了参数,因此文件会比较大。

2.3可能会遇到的问题

如果抛出诸如“lrs=[5],hrs[10]”类的错误([ ]中的内容也许会不同),说明cpkt输出类别数目10和前向传播图类别数目5不相等造成的,这是要检查你导入前向传播图的参数了,确认数据集类别数正确后在执行固化则不会报错。

 

三、pb文件转tflite文件

tflite模型相比于pb要精简的多,只是针对嵌入式平台进行的优化,对移植到嵌入式平台建议使用该方式,因为执行速度会更快,但是tflite并不是所有模型都支持,只支持部分模型。

转换形式使用bazel,bazel的具体安装和使用请参考我的另一篇博客:tensorflow深度学习实战笔记(四):bazel编译tensorflow工具的使用方法

bazel-bin/tensorflow/contrib/lite/toco/toco \
  --input_file=/home/yuping-chen/slim/my_model/frozen_model/frozen_inception_v3.pb \
  --input_format=TENSORFLOW_GRAPHDEF \
  --output_format=TFLITE \
  --output_file=/home/yuping-chen/slim/my_model/frozen_model/frozen_inception_v3.tflite \
  --inference_type=FLOAT \
  --input_type=FLOAT \
  --input_arrays=input \
  --output_arrays=Inception/Predictions/Reshape_1 \
  --input_shapes=1,299,299,3

运行结束后即可在你指定的位置生成相应的tflite文件,后面就可以用生成的tflite文件移植到手机端了,tensorflow lite手机端的移植方法可以参考我的另一篇博客:tensorflow深度学习实战笔记(三):使用tensorflow lite把训练好的模型移植到手机端,编译成apk文件

 

 

猜你喜欢

转载自blog.csdn.net/chenyuping333/article/details/82106863