Faster R-CNN from scratch in PyTorch (5): RoI

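My understanding is limited and there is a lot in this part I could not follow, so I have written up only the parts I did understand, to give a rough idea of what the code is trying to do.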

The rough steps are:

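1. Take the RoIs that the RPN proposed and that survived non-maximum suppression (NMS), together with the index of the image each RoI belongs to (roi_indices; these are image indices, not class labels), and the feature map produced by VGG16, and apply RoI pooling. RoI pooling is the single-level special case of spatial pyramid pooling: every RoI, whatever its size, is max-pooled to a fixed roi_size x roi_size grid.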

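2. Pass the pooled features through the fully connected classifier, then feed the result into the two output layers, cls_loc and score, to get the per-class box localization values (roi_cls_locs) and the class scores (roi_scores) for each RoI.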
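To make step 1 more concrete, here is a toy sketch of what RoI max pooling does to a single-channel feature map. It is not the RoIPooling2D layer used in the code below; the function name roi_max_pool, the floor/ceil binning, and the 1/16 scale (a typical value for VGG16 conv5 features) are my own assumptions for illustration.

```python
import math


def roi_max_pool(feature, roi, out_size, spatial_scale):
    """Max-pool one RoI of a single-channel feature map to out_size x out_size.

    feature: list of rows (H x W).
    roi: (x_min, y_min, x_max, y_max) in input-image pixels.
    spatial_scale: feature-map size divided by image size (assumed 1/16 here).
    """
    height, width = len(feature), len(feature[0])
    # Project the box from image coordinates onto the feature map.
    x0, y0, x1, y1 = (int(round(v * spatial_scale)) for v in roi)
    roi_w = max(x1 - x0 + 1, 1)
    roi_h = max(y1 - y0 + 1, 1)

    def clip(v, upper):
        return min(max(v, 0), upper)

    pooled = []
    for i in range(out_size):
        # Row range covered by the i-th output bin, clipped to the map.
        r0 = clip(y0 + math.floor(i * roi_h / out_size), height)
        r1 = clip(y0 + math.ceil((i + 1) * roi_h / out_size), height)
        row = []
        for j in range(out_size):
            c0 = clip(x0 + math.floor(j * roi_w / out_size), width)
            c1 = clip(x0 + math.ceil((j + 1) * roi_w / out_size), width)
            cells = [feature[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            # An empty bin (RoI outside the map) is filled with 0.
            row.append(max(cells) if cells else 0.0)
        pooled.append(row)
    return pooled


feature = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
whole = roi_max_pool(feature, (0, 0, 48, 48), out_size=2, spatial_scale=1 / 16)
corner = roi_max_pool(feature, (16, 16, 48, 48), out_size=2, spatial_scale=1 / 16)
print(whole)   # [[5, 7], [13, 15]]
print(corner)  # [[10, 11], [14, 15]]
```

The two RoIs cover different areas (4x4 and 3x3 cells on the feature map), yet both come out as 2x2 grids. That fixed output size is what lets every RoI go through the same fully connected layers afterwards.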

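import torch as t
from torch import nn

# normal_init, RoIPooling2D and at (an array/tensor conversion helper) come
# from the surrounding project and are not defined in this snippet.


class VGG16RoIHead(nn.Module):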
    """Faster R-CNN Head for VGG-16 based implementation.
    This class is used as a head for Faster R-CNN.
    This outputs class-wise localizations and classification based on feature
    maps in the given RoIs.

    Args:
        n_class (int): The number of classes possibly including the background.
        roi_size (int): Height and width of the feature maps after RoI-pooling.
        spatial_scale (float): Factor by which RoI coordinates are multiplied
            to map them from the input image onto the feature map.
        classifier (nn.Module): Two layer Linear ported from vgg16

    """

    def __init__(self, n_class, roi_size, spatial_scale,
                 classifier):
        # n_class includes the background
        super(VGG16RoIHead, self).__init__()

        self.classifier = classifier
        self.cls_loc = nn.Linear(4096, n_class * 4)
        self.score = nn.Linear(4096, n_class)

        normal_init(self.cls_loc, 0, 0.001)
        normal_init(self.score, 0, 0.01)

        self.n_class = n_class
        self.roi_size = roi_size
        self.spatial_scale = spatial_scale
        self.roi = RoIPooling2D(
            self.roi_size, self.roi_size, self.spatial_scale)

    def forward(self, x, rois, roi_indices):
        """Forward the chain.

        We assume that there are :math:`N` batches.

        Args:
            x (Tensor): 4D feature map extracted from the input images.
            rois (Tensor): A bounding box array containing coordinates of
                proposal boxes.  This is a concatenation of bounding box
                arrays from multiple images in the batch.
                Its shape is :math:`(R', 4)`. Given :math:`R_i` proposed
                RoIs from the :math:`i` th image,
                :math:`R' = \\sum _{i=1} ^ N R_i`.
            roi_indices (Tensor): An array containing indices of images to
                which bounding boxes correspond to. Its shape is :math:`(R',)`.

        """
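        # in case roi_indices is an ndarray
        roi_indices = at.totensor(roi_indices).float()
        rois = at.totensor(rois).float()
        # Concatenate roi_indices and rois into rows of
        # (index, y_min, x_min, y_max, x_max); index is the image in the
        # batch that the RoI belongs to, not a class label.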
        indices_and_rois = t.cat([roi_indices[:, None], rois], dim=1)

        # NOTE: important: yx->xy
        xy_indices_and_rois = indices_and_rois[:, [0, 2, 1, 4, 3]]
        indices_and_rois = xy_indices_and_rois.contiguous()

        # Apply RoI pooling to the feature map (one fixed-size grid per RoI)
        pool = self.roi(x, indices_and_rois)
        # Flatten each pooled RoI into a vector
        pool = pool.view(pool.size(0), -1)

        # Classifier: the fully connected layers taken from VGG16 (without
        # its final layer); yields one 4096-d feature vector per RoI
        fc7 = self.classifier(pool)

        # Per-class box localizations (n_class * 4 values) and class scores
        # (n_class values) for each RoI
        roi_cls_locs = self.cls_loc(fc7)
        roi_scores = self.score(fc7)
        return roi_cls_locs, roi_scores
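As a small check of my reading of forward(), the sketch below reproduces the index/box concatenation and the yx->xy swap on made-up boxes, then shows one way the two outputs could be interpreted. The values, n_class = 21, and the (R, n_class, 4) layout of roi_cls_locs are assumptions on my part; the random tensors only stand in for the real outputs so that the shapes can be followed.

```python
import torch

n_class = 21   # assumed: 20 object classes + background
num_rois = 3

# Made-up proposals as (y_min, x_min, y_max, x_max), all from image 0.
rois = torch.tensor([[10., 20., 110., 220.],
                     [0., 0., 50., 60.],
                     [30., 40., 300., 400.]])
roi_indices = torch.zeros(num_rois)

# Same steps as in forward(): prepend the image index, then swap y/x pairs.
indices_and_rois = torch.cat([roi_indices[:, None], rois], dim=1)
xy_indices_and_rois = indices_and_rois[:, [0, 2, 1, 4, 3]].contiguous()

# Random stand-ins for the two outputs of the head (shapes only).
roi_cls_locs = torch.randn(num_rois, n_class * 4)
roi_scores = torch.randn(num_rois, n_class)

# Assumed layout: 4 localization values per class for each RoI.
locs_per_class = roi_cls_locs.view(num_rois, n_class, 4)
probs = torch.softmax(roi_scores, dim=1)
best_class = probs.argmax(dim=1)
best_locs = locs_per_class[torch.arange(num_rois), best_class]

print(xy_indices_and_rois[0].tolist())  # [0.0, 20.0, 10.0, 220.0, 110.0]
print(best_locs.shape)                  # torch.Size([3, 4])
```

After the swap each row reads (index, x_min, y_min, x_max, y_max), which is presumably the order the pooling layer expects. Each RoI ends up with one score per class and one set of four localization values per class, so picking a class also picks which four values to use.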

My understanding is limited; corrections are welcome.


Reposted from blog.csdn.net/a362682954/article/details/82928568