版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u011240016/article/details/83993021
至此已经学习了如何选择预训练模型,将数据集转为TFRecord格式。模型和数据都准备好了,是时候开启训练了。
这些在COCO数据集上的模型都是针对90类进行识别的,如果自己的任务没有这么多类,或者类不同怎么办呢?
如果是我们不是用物体检测的API的话,答案是移除最后的90个类的分类器层,替换为一个新的神经网络层。
shape = (fc_2nd_last_get_shape().as_list()[-1],nb_classes)
fc_last_W = tf.Variable(tf.truncated_normal(shape, stddev=1e-2))
fc_last_b = tf.Variable(tf.zeros(nb_classes))
logits = tf.nn.xw_plus_b(fc_2nd_last, fc_last_W, fc_last_b)
但是对于物体检测的API而言,我们只需要修改一下配置文件即可。
在object_detection/samples/configs
文件夹下,有各种预训练模型的配置文件。
以Faster-RCNN举例子:
# Faster R-CNN with Inception Resnet v2, Atrous version, with Cosine
# Learning Rate schedule.
# Trained on COCO, initialized from Imagenet classification checkpoint
# Users should configure the fine_tune_checkpoint field in the train config as
# well as the label_map_path and input_path fields in the train_input_reader and
# eval_input_reader. Search for "PATH_TO_BE_CONFIGURED" to find the fields that
# should be configured.
——TBD—
参考: