ubuntu下将tensorflow训练好的模型移植到安卓端

1.下载tensorflow源码

git clone --recurse-submodules https://github.com/tensorflow/tensorflow.git

2.通过export_tflite_ssd_graph.py将训练后的模型导出所需要的文件

配置参数为

--pipeline_config_path=/home/jinyan/anaconda3/envs/tensorflow/models/research/object_detection/sm_products/MODEL_13/sm_pipeline.config
--trained_checkpoint_prefix=/home/jinyan/anaconda3/envs/tensorflow/models/research/object_detection/sm_products/MODEL_13/train/model.ckpt-39745
--output_directory=/home/jinyan/anaconda3/envs/tensorflow/models/research/object_detection/sm_products/MODEL_13/model_lite
--add_postprocessing_op=true

运行后将在output_directory目录生成 tflite_graph.pb 和 tflite_graph.pbtxt 两个文件。

3.下载bazel

下载地址:
https://docs.bazel.build/versions/master/install.html

4.编译转换工具

# 一定要进到tensorflow目录下,在有WORKSPACE的目录下进行编译
cd tensorflow
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco 

(注:新旧版本编译工具的位置不同)

要是编译成功,过程和结果如下所示:
在这里插入图片描述
在这里插入图片描述

5.用TOCO工具将训练好的模型从.pb格式转换为.tflite格式

注意:一定要进入到WORKSPACE所在的文件夹目录下进行下面命令操作
浮点型:

bazel run -c opt tensorflow/lite/toco:toco -- \
--input_file=/home/jinyan/anaconda3/envs/tensorflow/models/research/object_detection/sm_products/MODEL_13/model_lite/tflite_graph.pb \
--output_file=/home/jinyan/anaconda3/envs/tensorflow/models/research/object_detection/sm_products/MODEL_13/detect/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3'  \
--inference_type=FLOAT \
--allow_custom_ops

整数型:

bazel run --config=opt tensorflow/lite/toco:toco -- \
--input_file=/home/jinyan/anaconda3/envs/tensorflow/models/research/object_detection/sm_products/MODEL_13/model_lite/tflite_graph.pb \
--output_file=/home/jinyan/anaconda3/envs/tensorflow/models/research/object_detection/sm_products/MODEL_13/detect/detect.tflite \
--input_shapes=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
--inference_type=QUANTIZED_UINT8 \
--mean_values=128 \
--std_values=128 \
--change_concat_input_ranges=false \
--allow_custom_ops

若是转换成功,会产生 detect.tflite 文件。

6.下载AndroidStudio

详细过程可以参考另一篇博客:https://blog.csdn.net/weixin_43843657/article/details/88655239

7.将生成的tflit文件移植到AndroidStudio上

项目工程文件:https://github.com/tensorflow/tensorflow/tree/master/tensorflow/lite/examples/android
直接将该路径import到 Android Studio 里面
将detect.tflite文件拷贝到 android/app/src/main/assets 目录下,将训练模型的类别以txt文件的形式记录并存到assets下
在这里插入图片描述
对 DetectorActivity.java 进行修改,将物体检测模型替换成我们生成的模型(detect.tflite),并将模型对应的分类列表文件也进行替换;若是前面是用浮点数形式生成的,TF_OD_API_IS_QUANTIZED = false记得改正。
在这里插入图片描述
重新编译运行,即可生成测试APK,最终的安卓安装包路径:
在这里插入图片描述
查看build.gradle文科,可以确定该demo指定的目标运行系统的api level 为26在这里插入图片描述
若没有api level为26的模拟器就新建一个:
在这里插入图片描述
成功就会在安卓机上有个app,结束!
(若是碰上安装的app存在闪退现象,可尝试清理下手机内存)

注:官方网站教程 https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/running_on_mobile_tensorflowlite.md
  1. 因为类别数组越界造成的app闪退,需进行以下代码修改:
    在这里插入图片描述
  2. 若出现框的score大于1的情况,需进行以下代码修改:
    在这里插入图片描述
  3. 在没有被检测物体时仍然有框

猜你喜欢

转载自blog.csdn.net/weixin_43843657/article/details/88906560