Tensorflow版Faster RCNN源码解析(TFFRCNN) (12) gt_data_layer/layer.py

本blog为github上CharlesShang/TFFRCNN版源码解析系列代码笔记

---------------个人学习笔记---------------

----------------本文作者吴疆--------------

------点击此处链接至博客园原文------

通过find in path发现该代码段函数均未被执行(GtDataLayer在train.py中被注释调用,与caffe modules有关),定义函数与roi_data_layer/layer.py类似

"""
The data layer used during training to train a Fast R-CNN network.
GtDataLayer implements a Caffe Python layer.
"""

定义了一个GtDataLayer类,未见调用,类内定义了8个函数,分别是:

class GtDataLayer(caffe.Layer):
    """Fast R-CNN data layer used for training."""

1._shuffle_roidb_inds(self)

将所有图像rois构成的roidb随机打乱顺序,得到self._perm数组和self._cur起始标志,与roi_data_layer/layer.py中类似,被_get_next_minibatch_inds(...)和set_roidb(...)调用

    def _shuffle_roidb_inds(self):
        """Randomly permute the training roidb."""
        # 随机打乱顺序
        self._perm = np.random.permutation(np.arange(len(self._roidb)))
        # 起始标志
        self._cur = 0

2._get_next_minibatch_inds(self)

获取下一个minibatch的索引(cfg.TRAIN.IMS_PER_BATCH个)并更新self._cur的值,被_get_next_minibatch(...)函数调用

    # 获取下一个minibatch的索引(cfg.TRAIN.IMS_PER_BATCH个)并更新self._cur的值
    def _get_next_minibatch_inds(self):
        """Return the roidb indices for the next minibatch.

        Takes the next ``cfg.TRAIN.IMS_PER_BATCH`` entries of the current
        permutation ``self._perm`` starting at ``self._cur`` and advances the
        cursor past them. When fewer than a full batch of entries remain,
        the roidb is reshuffled first and the batch is taken from the start
        of the new permutation.

        Returns:
            A slice of ``self._perm`` holding the selected roidb indices. It
            is shorter than the batch size only when the roidb itself has
            fewer entries than ``cfg.TRAIN.IMS_PER_BATCH``.
        """
        batch_size = cfg.TRAIN.IMS_PER_BATCH
        # Reshuffle only when the remaining entries cannot fill a whole
        # batch. The comparison must be strict: with ``>=`` a batch ending
        # exactly on the last entry is discarded, so the tail of every
        # permutation would never be visited.
        if self._cur + batch_size > len(self._roidb):
            self._shuffle_roidb_inds()
        batch_end = self._cur + batch_size
        db_inds = self._perm[self._cur:batch_end]
        self._cur = batch_end
        return db_inds

3._get_next_minibatch(self)

获取下一个minibatch,并将其作为参数调用minibatch.py中的get_minibatch(...)函数,以更新roidb[i]['info_boxes']字段、增加'data'和'parameters'字段组成blobs并返回

    def _get_next_minibatch(self):
        """Assemble the blobs for the next minibatch.

        Picks the next batch of roidb indices, gathers the matching roidb
        entries and hands them, together with ``self._num_classes``, to
        ``get_minibatch``.

        Returns:
            Whatever ``get_minibatch`` returns.
            NOTE(review): ``forward`` iterates it as a mapping whose keys
            are looked up in ``self._name_to_top_map`` -- confirm against
            minibatch.py.
        """
        # Indices come from the current shuffled order of the roidb.
        selected_entries = [
            self._roidb[idx] for idx in self._get_next_minibatch_inds()
        ]
        return get_minibatch(selected_entries, self._num_classes)

4.def set_roidb(self, roidb)

初始化roidb,获取self._perm和self._cur

    # this function is called in training the net
    def set_roidb(self, roidb):
        """Set the roidb to be used by this layer during training."""
        self._roidb = roidb
        self._shuffle_roidb_inds()

5.def setup(self, bottom, top)

对GtDataLayer top的reshape处理

    def setup(self, bottom, top):
        """Set up the GtDataLayer: parse its parameters and size the tops.

        Args:
            bottom: Unused.
            top: Output blobs. ``top[0]`` is the data blob, ``top[1]`` the
                info-boxes blob and ``top[2]`` the parameters blob, matching
                ``self._name_to_top_map``.
        """
        # Parse the layer parameter string, which must be valid YAML.
        # safe_load is used instead of yaml.load: calling yaml.load without
        # an explicit Loader is deprecated in PyYAML 5.1+ and raises
        # TypeError in PyYAML 6, and the parameters are plain scalars that
        # need no arbitrary object construction.
        layer_params = yaml.safe_load(self.param_str_)
        self._num_classes = layer_params['num_classes']
        self._name_to_top_map = {
            'data': 0,
            'info_boxes': 1,
            'parameters': 2}
        # Data blob: 3 channels per image. The height and width (100 x 100)
        # are dummy placeholder values; forward() reshapes every top to the
        # actual blob shape.
        # NOTE(review): the leading dimension is the number of base scales
        # (the blog notes a default TRAIN.SCALES_BASE of
        # (0.25, 0.5, 1.0, 2.0, 3.0)); presumably one image per base scale,
        # tied to the caffe module -- confirm.
        num_scale_base = len(cfg.TRAIN.SCALES_BASE)
        top[0].reshape(num_scale_base, 3, 100, 100)
        # Info-boxes blob.
        # NOTE(review): the meaning of the 18 columns is not visible here.
        top[1].reshape(1, 18)
        # Parameters blob: 2 values plus 2 per scale and 2 per aspect ratio.
        num_scale = len(cfg.TRAIN.SCALES)
        num_aspect = len(cfg.TRAIN.ASPECTS)
        top[2].reshape(2 + 2*num_scale + 2*num_aspect)

6.def forward(self, bottom, top)

对top赋值操作

    def forward(self, bottom, top):
        """Fetch the next minibatch and copy its blobs into ``top``.

        Args:
            bottom: Unused.
            top: Output blobs, indexed through ``self._name_to_top_map``
                (populated in ``setup``: 'data' -> 0, 'info_boxes' -> 1,
                'parameters' -> 2).
        """
        blobs = self._get_next_minibatch()
        # items() exists on both Python 2 and Python 3 dicts, whereas
        # iteritems() raises AttributeError on Python 3.
        for blob_name, blob in blobs.items():
            top_ind = self._name_to_top_map[blob_name]
            # Resize the net's input blob to this minibatch's actual shape.
            top[top_ind].reshape(*(blob.shape))
            # Copy the data in as float32; copy=False avoids an extra copy
            # when the blob already has that dtype.
            top[top_ind].data[...] = blob.astype(np.float32, copy=False)

7.def backward(self, top, propagate_down, bottom)

    def backward(self, top, propagate_down, bottom):
        """Do nothing: this layer does not propagate gradients.

        All arguments are ignored.
        """
        return None

8.def reshape(self, bottom, top)

    def reshape(self, bottom, top):
        """Do nothing: reshaping happens during the call to forward.

        All arguments are ignored.
        """
        return None

猜你喜欢

转载自www.cnblogs.com/deeplearning1314/p/11325011.html
今日推荐