整个流程稍微有点长,但是如果走通后再来新的任务就很简单了。大概分为以下几个步骤:
- 训练SSD
- 转tflite
- 在android应用中运行
训练SSD
1.代码下载
https://github.com/tensorflow/models.git
我们会使用ROOT/research/object_detection下的代码来训练SSD。
2.环境配置
除了tensorflow外还需要安装下面这些模块
pip install --user Cython
pip install --user contextlib2
pip install --user pillow
pip install --user lxml
pip install --user jupyter
pip install --user matplotlib
也可以后面使用的时候根据错误提示,缺了什么再安装对应的模块。
另外还需要配置protoc,下载地址,下载后解压放在ROOT/research目录下,将文件夹的名字更改为protoc,然后运行
./protoc/bin/protoc object_detection/protos/*.proto --python_out=.
运行之后在ROOT/research/object_detection/protos/下会生成很多python文件,后面代码中会使用这些python文件。
为了验证我们环境或者python文件是否正确,可以运行下面的命令
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim
python object_detection/builders/model_builder_test.py
如果出现下面的结果就表示准备好了
......................
----------------------------------------------------------------------
Ran 22 tests in 0.191s
OK
3.准备训练数据
标记数据
使用labelImg工具(下载地址)对数据进行标记,会对每张图片生成一个xml文件。
将标记后的图片分为train和test数据,分别放在train和test文件夹下,xml文件要跟着图片一起。
将xml和图片转成tfrecord
先将xml转为csv格式,可以参考xml_to_csv.py文件,根据自己的情况把PROJECT改成自己的目录地址就好了,然后运行这个python文件,可以得到train_labels.csv和test_labels.csv。
再将图片和csv一起转换为tfrecord文件,可以参考generate_tfrecord.py文件。运行
python generate_tfrecord.py --dataset=boxes
这里的dataset是我对自己的数据集合的命名。另外如果有不同的类别需要在class_text_to_int函数中增加自己的类别,我这边只识别一个类别。
def class_text_to_int(row_label):
if row_label == 'box':
return 1
//if row_label == 'class2':
//return 2
//if row_label == 'class3':
//return 3
else:
None
4.开始训练
创建label文件
参考pbtxt 目录下的文件写自己的pbtxt文件,这个文件就是写出自己有多少个类别需要识别,比如两个类别可以这样写
item {
id: 1
name: 'car'
}
item {
id: 2
name: 'pedestrian'
}
创建config文件
config文件主要是配置一些训练的参数,如果类别,batch size,优化方式,anchor生成的方式等等。
可以参考config参考文件中的congfig文件进行修改,这个目录下有很多网络的config文件,我这边使用了ssd_mobilenet_v2_coco.config作为base进行了修改,主要改动以下几个地方:
model {
ssd {
num_classes: 1 ---> number of classes
box_coder {
...
train_input_reader: {
tf_record_input_reader {
input_path: "project_images/boxes/train.tfrecord" ---> train tfrecored path
}
label_map_path: "project_images/boxes/labelmap.pbtxt" ---> train label path
}
...
eval_input_reader: {
tf_record_input_reader {
input_path: "project_images/boxes/test.tfrecord" ---> test tfrecored path
}
label_map_path: "project_images/boxes/labelmap.pbtxt" ---> test label path
shuffle: false
num_readers: 1
}
训练
使用legacy/train.py进行训练,只需要指出训练后的数据存放在哪里,以及config文件的路径
python legacy/train.py --logtostderr \
--train_dir=project_images/boxes/logs \
--pipeline_config_path=project_images/boxes/config/ssd_mobilenet_v2.config
当我们训练一段时间后,在logs目录下会得到一些ckpt文件。
model.ckpt-7386.data-00000-of-00001
model.ckpt-7386.index
model.ckpt-7386.meta
生成tflite文件
1.生成frozen pb文件
在生成tflite文件之前,我们需要将ckpt文件的ckpt-data与ckpt-meta一起生成pb文件,就是将计算图与参数的值融合起来生成一个文件。
python object_detection/export_tflite_ssd_graph.py \
--pipeline_config_path=object_detection/project_images/boxes/config/ssd_mobilenet_v2.config \
--trained_checkpoint_prefix=object_detection/project_images/boxes/logs/model.ckpt-7386 \
--output_directory=object_detection/project_images/boxes/logs/tflite \
--add_postprocessing_op=true
会在tflite目录下生成以下文件:
tflite_graph.pb
tflite_graph.pbtxt
2.生成tflite文件
用生成的tflite_graph.pb文件来生成tflite文件,需要使用到toco工具,如果是第一次使用需要编译toco工具。
下载tensorflow源码:
git clone https://github.com/tensorflow/tensorflow.git
编译toco:
bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco
生成tflite文件:
bazel run tensorflow/lite/toco:toco -- --input_file=/local/deeplearning/models/research/object_detection/project_images/boxes/logs/tflite/tflite_graph.pb --output_file=/local/deeplearning/models/research/object_detection/project_images/boxes/logs/tflite/ssd_mobilenetv2.tflite --input_shapes=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' --inference_type=FLOAT --allow_custom_ops
运行上面的命令后会生成ssd_mobilenetv2.tflite文件。
在android中应用tflite文件
可以参考tensorflow中的demo来实现,tensorflow demo
首先将生成的tflite文件拷贝在assets目录下,并且在这个目录下创建一个txt文件,内容就是需要识别的类别
???
boxes
第一行写???表明这一类是background,如果不这样写似乎apk运行会crash。
接着修改BUILD文件,将tflite文件和txt文件都在assets中指明。
android_binary(
name = "boxes",
srcs = glob([
"app/src/main/java/**/*.java",
]),
aapt_version = "aapt",
# Package assets from assets dir as well as all model targets.
# Remove undesired models (and corresponding Activities in source)
# to reduce APK size.
assets = [
"//tensorflow/lite/examples/boxes/app/src/main/assets:ssd_mobilenetv2.tflite",
"//tensorflow/lite/examples/boxes/app/src/main/assets:ssd_mobilenetv2_labels.txt",
],
......
还需要在DetectorActivity.java中修改
private static final String TF_OD_API_MODEL_FILE = "ssd_mobilenetv2.tflite";
private static final String TF_OD_API_LABELS_FILE = "ssd_mobilenetv2_labels.txt";
最后使用bazel进行编译就可以生成apk文件了
bazel build -c opt --cxxopt='--std=c++11' --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
//tensorflow/lite/examples/android:tflite_demo
生成的apk文件可以根据需求进行物体定位。