接上篇【Tensorflow】object_detection:SSD_MobileNetV2训练VOC数据集训练完成之后,生成了一系列ckpt文件,现在来将训练生成的ckpt模型固化。
tensorflow工具的编译请参考链接:【Tensorflow】bazel编译tensorflow工具summarize_graph、freeze_graph、toco
一、模型固化
以下命令的执行路径是models/research
使用的的工具是object_detection下面的export_tflite_ssd_graph.py,命令是:
python object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path=/home/models/research/object_detection/samples/configs/ssd_mobilenet_v2_pascal.config \
--trained_checkpoint_prefix=/home/data/VOCdekit/ssd_mobilnet_v2/model.ckpt-50000 \
--output_directory=/home/data/VOCdekit/ssd_mobilnet_v2/output \
--add_postprocessing_op=true
注意在执行命令之前要将slim加入环境变量
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
完成之后会生成tflite_graph.pb 和tflite_graph.pbtxt两个文件。
二、toco优化
以下命令的执行路径是tensorflow根目录。
用toco工具将固化的模型转换成tflite文件可以在移动端使用。
bazel run --config=opt /tensorflow/contrib/lite/toco:toco -- \
--input_file=/home/data/VOCdekit/ssd_mobilnet_v2/output/tflite_graph.pb \
--output_file=/home/data/VOCdekit/ssd_mobilnet_v2/output/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=FLOAT \
--allow_custom_ops
完成之后生成了detect.tflite文件。