【CV】如何使用Tensorflow提供的Object Detection API--4--开始训练模型

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011240016/article/details/83993021

至此已经学习了如何选择预训练模型,将数据集转为TFRecord格式。模型和数据都准备好了,是时候开启训练了。

这些在COCO数据集上的模型都是针对90类进行识别的,如果自己的任务没有这么多类,或者类不同怎么办呢?

如果是我们不是用物体检测的API的话,答案是移除最后的90个类的分类器层,替换为一个新的神经网络层

shape = (fc_2nd_last_get_shape().as_list()[-1],nb_classes)
fc_last_W = tf.Variable(tf.truncated_normal(shape, stddev=1e-2))
fc_last_b = tf.Variable(tf.zeros(nb_classes))
logits = tf.nn.xw_plus_b(fc_2nd_last, fc_last_W, fc_last_b)

但是对于物体检测的API而言,我们只需要修改一下配置文件即可。

object_detection/samples/configs文件夹下,有各种预训练模型的配置文件

以Faster-RCNN举例子:

# Faster R-CNN with Inception Resnet v2, Atrous version, with Cosine
# Learning Rate schedule.
# Trained on COCO, initialized from Imagenet classification checkpoint
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.

——TBD—

参考:

https://medium.com/@WuStangDan/step-by-step-tensorflow-object-detection-api-tutorial-part-4-training-the-model-68a9e5d5a333

猜你喜欢

转载自blog.csdn.net/u011240016/article/details/83993021