简介
在桌面PC或是服务器上使用TensorFlow训练出来的模型文件,不能直接用在TFLite上运行,需要使用离线工具先转成.tflite文件。笔者发现官方文档中很多细节介绍的都不太明确,在使用过程中需要不断尝试。我把自己的尝试过的步骤分享出来,希望能帮助大家节省时间。
具体说来,tflite文件的生成大致分为3步:
1. 在算法训练的脚本中保存图模型文件(GraphDef)和变量文件(CheckPoint)。
2. 利用freeze_graph工具生成frozen的graphdef文件。
3. 利用toco(Tensorflow Optimizing COnverter)工具,生成最终的tflite文件。
图1. 生成tflite文件的整个流程示意图
第1步:导出图模型文件和变量文件
在你的算法的训练或推理任务的脚本中,利用tensorflow.train中的write_graph和saver API来导出GraphDef及Checkpoint文件。
这样我们可以拿到模型的pb文件或ckpt文件
第2步:freeze graph
可参考Tensorflow C++ API线上预测服务的文档。即我的另一篇博客《使用TensorFlow C++ API构建线上预测服务》https://blog.csdn.net/lsj1342/article/details/82752951
第3步:生成最终的tflite文件
构建toco工具
bazel build tensorflow/contrib/lite/toco:toco
构建成功后,可在tensorflow/bazel-bin/tensorflow/contrib/lite/toco中看到可执行文件toco,将路径加入环境变量
toco --input_file=/data/liusijia/tensorflowTest/freeze_graph.pb \ //路径自己指定,这里是freeze_graph后pb文件
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=/data/liusijia/tensorflowTest/tmp.tflite \ //路径自己指定
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=Placeholder \ //输入节点的名称
--output_arrays=Softmax \ //输出节点的名称
--input_shapes=1,784 //输入节点的维度