使用tensorflow自带model训练SSD并且在手机上运行

整个流程稍微有点长,但是如果走通后再来新的任务就很简单了。大概分为以下几个步骤:

  1. 训练SSD
  2. 转tflite
  3. 在android应用中运行

训练SSD

1.代码下载

https://github.com/tensorflow/models.git

我们会使用ROOT/research/object_detection下的代码来训练SSD。

2.环境配置

除了tensorflow外还需要安装下面这些模块

pip install --user Cython
pip install --user contextlib2
pip install --user pillow
pip install --user lxml
pip install --user jupyter
pip install --user matplotlib

也可以后面使用的时候根据错误提示,缺了什么再安装对应的模块。

另外还需要配置protoc,下载地址,下载后解压放在ROOT/research目录下,将文件夹的名字更改为protoc,然后运行

./protoc/bin/protoc object_detection/protos/*.proto --python_out=.

运行之后在ROOT/research/object_detection/protos/下会生成很多python文件,后面代码中会使用这些python文件。

为了验证我们环境或者python文件是否正确,可以运行下面的命令

export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

python object_detection/builders/model_builder_test.py

如果出现下面的结果就表示准备好了

......................
----------------------------------------------------------------------
Ran 22 tests in 0.191s

OK

3.准备训练数据

标记数据

使用labelImg工具(下载地址)对数据进行标记,会对每张图片生成一个xml文件。

将标记后的图片分为train和test数据,分别放在train和test文件夹下,xml文件要跟着图片一起。

将xml和图片转成tfrecord

先将xml转为csv格式,可以参考xml_to_csv.py文件,根据自己的情况把PROJECT改成自己的目录地址就好了,然后运行这个python文件,可以得到train_labels.csv和test_labels.csv。

再将图片和csv一起转换为tfrecord文件,可以参考generate_tfrecord.py文件。运行

python generate_tfrecord.py --dataset=boxes

这里的dataset是我对自己的数据集合的命名。另外如果有不同的类别需要在class_text_to_int函数中增加自己的类别,我这边只识别一个类别。

def class_text_to_int(row_label):
	if row_label == 'box':
		return 1
	//if row_label == 'class2':
		//return 2
	//if row_label == 'class3':
		//return 3
	else:
		None

4.开始训练

创建label文件

参考pbtxt 目录下的文件写自己的pbtxt文件,这个文件就是写出自己有多少个类别需要识别,比如两个类别可以这样写

item {
  id: 1
  name: 'car'
}

item {
  id: 2
  name: 'pedestrian'
}

创建config文件

config文件主要是配置一些训练的参数,如果类别,batch size,优化方式,anchor生成的方式等等。

可以参考config参考文件中的congfig文件进行修改,这个目录下有很多网络的config文件,我这边使用了ssd_mobilenet_v2_coco.config作为base进行了修改,主要改动以下几个地方:

model {
  ssd {
    num_classes: 1 ---> number of classes
    box_coder {
...
train_input_reader: {
  tf_record_input_reader {
    input_path: "project_images/boxes/train.tfrecord" ---> train tfrecored path
  }
  label_map_path: "project_images/boxes/labelmap.pbtxt" ---> train label path
}
...

eval_input_reader: {
  tf_record_input_reader {
    input_path: "project_images/boxes/test.tfrecord" ---> test tfrecored path
  }
  label_map_path: "project_images/boxes/labelmap.pbtxt" ---> test label path
  shuffle: false
  num_readers: 1
}

训练

使用legacy/train.py进行训练,只需要指出训练后的数据存放在哪里,以及config文件的路径

python legacy/train.py --logtostderr \
                     --train_dir=project_images/boxes/logs \
                     --pipeline_config_path=project_images/boxes/config/ssd_mobilenet_v2.config

当我们训练一段时间后,在logs目录下会得到一些ckpt文件。

model.ckpt-7386.data-00000-of-00001
model.ckpt-7386.index
model.ckpt-7386.meta

 

生成tflite文件

1.生成frozen pb文件

在生成tflite文件之前,我们需要将ckpt文件的ckpt-data与ckpt-meta一起生成pb文件,就是将计算图与参数的值融合起来生成一个文件。

python object_detection/export_tflite_ssd_graph.py \
		--pipeline_config_path=object_detection/project_images/boxes/config/ssd_mobilenet_v2.config \
		--trained_checkpoint_prefix=object_detection/project_images/boxes/logs/model.ckpt-7386 \
		--output_directory=object_detection/project_images/boxes/logs/tflite \
		--add_postprocessing_op=true

会在tflite目录下生成以下文件:

tflite_graph.pb
tflite_graph.pbtxt

2.生成tflite文件

用生成的tflite_graph.pb文件来生成tflite文件,需要使用到toco工具,如果是第一次使用需要编译toco工具。

下载tensorflow源码:

git clone https://github.com/tensorflow/tensorflow.git

 编译toco:

bazel build tensorflow/python/tools:freeze_graph
bazel build tensorflow/lite/toco:toco

生成tflite文件:

bazel run tensorflow/lite/toco:toco -- --input_file=/local/deeplearning/models/research/object_detection/project_images/boxes/logs/tflite/tflite_graph.pb --output_file=/local/deeplearning/models/research/object_detection/project_images/boxes/logs/tflite/ssd_mobilenetv2.tflite --input_shapes=1,300,300,3 --input_arrays=normalized_input_image_tensor --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' --inference_type=FLOAT --allow_custom_ops

运行上面的命令后会生成ssd_mobilenetv2.tflite文件。

 

在android中应用tflite文件

可以参考tensorflow中的demo来实现,tensorflow demo

首先将生成的tflite文件拷贝在assets目录下,并且在这个目录下创建一个txt文件,内容就是需要识别的类别

???
boxes

第一行写???表明这一类是background,如果不这样写似乎apk运行会crash。

接着修改BUILD文件,将tflite文件和txt文件都在assets中指明。

android_binary(
    name = "boxes",
    srcs = glob([
        "app/src/main/java/**/*.java",
    ]),
    aapt_version = "aapt",
    # Package assets from assets dir as well as all model targets.
    # Remove undesired models (and corresponding Activities in source)
    # to reduce APK size.
    assets = [
        "//tensorflow/lite/examples/boxes/app/src/main/assets:ssd_mobilenetv2.tflite",
        "//tensorflow/lite/examples/boxes/app/src/main/assets:ssd_mobilenetv2_labels.txt",
    ],
......

还需要在DetectorActivity.java中修改

  private static final String TF_OD_API_MODEL_FILE = "ssd_mobilenetv2.tflite";
  private static final String TF_OD_API_LABELS_FILE = "ssd_mobilenetv2_labels.txt";

最后使用bazel进行编译就可以生成apk文件了

  bazel build -c opt --cxxopt='--std=c++11' --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \
    //tensorflow/lite/examples/android:tflite_demo

生成的apk文件可以根据需求进行物体定位。

猜你喜欢

转载自blog.csdn.net/stesha_chen/article/details/86741474
今日推荐