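There is a lot in this part that I did not fully understand, so I can only write up the pieces I did follow. Treat it as a rough guide to what the code is trying to do.
The rough steps are: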
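1. Take the RoIs proposed by the RPN and filtered with NMS (non-maximum suppression), the index of the image each RoI belongs to, and the feature map produced by VGG16, then apply RoI pooling (a single-level special case of spatial pyramid pooling) so that every RoI yields a fixed-size feature map.
2. Flatten the pooled features, pass them through the classifier, and feed the result into cls_loc and score to get the per-class box regression outputs (roi_cls_locs) and class scores (roi_scores) for each RoI.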
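To make step 1 more concrete, here is a minimal pure-Python sketch of RoI max pooling on a single-channel feature map. It is only an illustration of the idea, not the kernel behind RoIPooling2D: the function name roi_max_pool, the rounding rule and the way bins are split are my own assumptions.

```python
import math


def roi_max_pool(feature, roi, out_size, spatial_scale):
    """Max-pool one RoI of a single-channel feature map to out_size x out_size.

    feature: list of rows (H x W).
    roi: (y1, x1, y2, x2) in input-image coordinates (hypothetical layout).
    spatial_scale: feature-map size divided by input-image size.
    """
    height, width = len(feature), len(feature[0])
    # Project the RoI from image coordinates onto the feature map.
    y1, x1, y2, x2 = (int(round(v * spatial_scale)) for v in roi)
    roi_h = max(y2 - y1 + 1, 1)
    roi_w = max(x2 - x1 + 1, 1)

    def bin_range(start, extent, k, limit):
        # Boundaries of bin k along one axis, clipped to the feature map.
        lo = start + math.floor(k * extent / out_size)
        hi = start + math.ceil((k + 1) * extent / out_size)
        return min(max(lo, 0), limit), min(max(hi, 0), limit)

    pooled = []
    for i in range(out_size):
        r0, r1 = bin_range(y1, roi_h, i, height)
        row = []
        for j in range(out_size):
            c0, c1 = bin_range(x1, roi_w, j, width)
            cells = [feature[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            # An empty bin (RoI outside the map) contributes 0.
            row.append(max(cells) if cells else 0.0)
        pooled.append(row)
    return pooled


feature_map = [[0, 1, 2, 3],
               [4, 5, 6, 7],
               [8, 9, 10, 11],
               [12, 13, 14, 15]]
# A 1/16 scale maps image coordinates onto this 4x4 feature map.
whole = roi_max_pool(feature_map, (0, 0, 48, 48), 2, 1.0 / 16)
part = roi_max_pool(feature_map, (16, 0, 48, 32), 2, 1.0 / 16)
```

Both RoIs cover different areas of the feature map, yet both come out as 2x2 grids. That fixed output size is what lets every RoI be flattened and fed to the same fully connected layers.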
import torch as t
from torch import nn

# at (array helpers), normal_init and RoIPooling2D are defined elsewhere
# in the project and are not shown here.


class VGG16RoIHead(nn.Module):
    """Faster R-CNN Head for VGG-16 based implementation.
    This class is used as a head for Faster R-CNN.
    This outputs class-wise localizations and classification based on feature
    maps in the given RoIs.
    Args:
        n_class (int): The number of classes possibly including the background.
        roi_size (int): Height and width of the feature maps after RoI-pooling.
        spatial_scale (float): Scale of the roi is resized.
        classifier (nn.Module): Two layer Linear ported from vgg16
    """

    def __init__(self, n_class, roi_size, spatial_scale,
                 classifier):
        # n_class includes the background
        super(VGG16RoIHead, self).__init__()
        self.classifier = classifier
        self.cls_loc = nn.Linear(4096, n_class * 4)
        self.score = nn.Linear(4096, n_class)
        normal_init(self.cls_loc, 0, 0.001)
        normal_init(self.score, 0, 0.01)
        self.n_class = n_class
        self.roi_size = roi_size
        self.spatial_scale = spatial_scale
        self.roi = RoIPooling2D(
            self.roi_size, self.roi_size, self.spatial_scale)

    def forward(self, x, rois, roi_indices):
        """Forward the chain.
        We assume that there are :math:`N` batches.
        Args:
            x (Variable): 4D image variable.
            rois (Tensor): A bounding box array containing coordinates of
                proposal boxes. This is a concatenation of bounding box
                arrays from multiple images in the batch.
                Its shape is :math:`(R', 4)`. Given :math:`R_i` proposed
                RoIs from the :math:`i` th image,
                :math:`R' = \\sum _{i=1} ^ N R_i`.
            roi_indices (Tensor): An array containing indices of images to
                which bounding boxes correspond to. Its shape is :math:`(R',)`.
        """
        # in case roi_indices is ndarray
        roi_indices = at.totensor(roi_indices).float()
        rois = at.totensor(rois).float()
        # Prepend roi_indices to rois, giving rows of (index, y1, x1, y2, x2).
        # The index is the image in the batch that the RoI belongs to,
        # not a class label.
        indices_and_rois = t.cat([roi_indices[:, None], rois], dim=1)
        # NOTE: important: yx->xy
        xy_indices_and_rois = indices_and_rois[:, [0, 2, 1, 4, 3]]
        indices_and_rois = xy_indices_and_rois.contiguous()
        # RoI-pool the feature map: one roi_size x roi_size map per RoI
        pool = self.roi(x, indices_and_rois)
        # Flatten each pooled RoI into a vector
        pool = pool.view(pool.size(0), -1)
        # Fully connected layers taken from the VGG16 classifier;
        # they produce a 4096-d feature vector per RoI
        fc7 = self.classifier(pool)
        # Per-class box regression outputs (4 values per class)
        # and class scores for each RoI
        roi_cls_locs = self.cls_loc(fc7)
        roi_scores = self.score(fc7)
        return roi_cls_locs, roi_scores
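The two outputs are flat per-RoI vectors: roi_cls_locs has n_class * 4 values and roi_scores has n_class values. The code above does not show how they are consumed later, so the following is only a sketch of one plausible reading, in pure Python for a single RoI. The function name decode_head_output and the softmax-then-argmax selection are my assumptions, not something taken from this code.

```python
import math


def decode_head_output(cls_loc_row, score_row):
    """Pick the best class for one RoI and return its 4 regression values.

    cls_loc_row: flat list of n_class * 4 values (one row of roi_cls_locs).
    score_row: list of n_class raw scores (one row of roi_scores).
    """
    n_class = len(score_row)
    assert len(cls_loc_row) == n_class * 4
    # Regroup the flat vector into one 4-value entry per class.
    locs = [cls_loc_row[c * 4:(c + 1) * 4] for c in range(n_class)]
    # Numerically stable softmax over the raw scores (an assumption here).
    top = max(score_row)
    exps = [math.exp(s - top) for s in score_row]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(n_class), key=probs.__getitem__)
    return best, probs[best], locs[best]


# Hypothetical values for a head with n_class = 3.
loc_row = [float(v) for v in range(12)]
score_row = [0.1, 2.0, -1.0]
best_class, best_prob, best_loc = decode_head_output(loc_row, score_row)
```

The point is the layout: each class gets its own group of 4 regression values, which is why cls_loc is built as nn.Linear(4096, n_class * 4). Turning those values into actual box coordinates is a separate step that is not covered here.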
My understanding is limited, so corrections are welcome.