小白带你用TensorFlow Lite对象检测
首先先了解一下TensorFlow Lite是什么?
概述
通过摄像头的应用程序,可以是手机上的摄像头或者是可以带动摄像头的终端设备,它使用在COCO数据集上训练的量化MobileNet SSD模型,连续检测设备后置摄像头看到的帧中的对象(边界框和类)。这些说明将引导您在Android设备上构建和运行演示程序。
模型文件在构建和运行时通过Gradle脚本下载。不需要执行任何步骤即可将TFLite模型显式下载到项目中。
应用程序可以在设备或模拟器上运行
Android Studio
先决条件
- 安装Android Studio。
- Android开发环境,最小API 21。
- Android Studio 3.2款或者以上版本。
复制代码
打开Android Studio,并从欢迎屏幕中选择打开现有Android Studio项目。
从出现的“打开文件或项目”窗口中,导航到tensorflow lite/examples/object_detection/android目录,并从克隆tensorflow lite示例GitHub repo的位置选择该目录。单击“确定”。
git clone https://github.com/tensorflow/models/
编译和安装
cd models
python3 setup.py build && python3 setup.py install
需要下载训练集
在此处下载在Open Images v4上训练过的MobileNet固态硬盘。提取预训练的TensorFlow模型文件。
转到models/research目录并执行此代码以获取冻结的TensorFlow Lite图。
python3 object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path object_detection/samples/configs/ssd_mobilenet_v2_oid_v4.config \
--trained_checkpoint_prefix <directory with ssd_mobilenet_v2_oid_v4_2018_12_12>/model.ckpt \
--output_directory exported_model
将冻结的图形转换为TFLite模型
tflite_convert \
--input_shape=1,300,300,3 \
--input_arrays=normalized_input_image_tensor \
--output_arrays=TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3 \
--allow_custom_ops \
--graph_def_file=exported_model/tflite_graph.pb \
--output_file=<directory with the TensorFlow examples repository>/lite/examples/object_detection/android/app/src/main/assets/detect.tflite
- input_shape=1,300m300,3,预训练模型仅适用于该输入形状。
- allow_custom_ops 是允许TFLite_Detection_后处理操作所必需的。
- 可以从示例检测模型的可视化图形中绘制输入数组和输出数组。
bazel run //tensorflow/lite/tools:visualize \
"<directory with the TensorFlow examples repository>/lite/examples/object_detection/android/app/src/main/assets/detect.tflite" \
detect.html