【pytorch】Mask-RCNN官方源码剖析(Ⅱ)

./maskrcnn_benchmark/structures/

定义了检测模式下包含的数据结构:

  • bounding_box.py
    定义了class BoxList(object) 类,该类用于表示一系列的bounding boxes。这些boxes会以 N * 4大小的tensor来表示。为了唯一确定boxes在图片中的准确位置,该类还保存了图片的维度,另外也可以添加额外的信息到特定的bounding box中,如标签信息。
import torch

# transpose
FLIP_LEFT_RIGHT = 0
FLIP_TOP_BOTTOM = 1


class BoxList(object):
    """
    This class represents a set of bounding boxes.
    The bounding boxes are represented as a Nx4 Tensor.
    In order to uniquely determine the bounding boxes with respect
    to an image, we also store the corresponding image dimensions.
    They can contain extra information that is specific to each bounding box, such as
    labels.
    """

    def __init__(self, bbox, image_size, mode="xyxy"):
        # bbox(tensor): n x 4,代表n个box,如:[[0,0,10,10],[0,0,5,5]]
        # image size:(width,height)
        # 根据bbox的数据类型获取对应的device
        device = bbox.device if isinstance(bbox, torch.Tensor) else torch.device("cpu")
        # 将bbox 转换成 tensor 类型
        bbox = torch.as_tensor(bbox, dtype=torch.float32, device=device)
        # bbox 的维度数量必须为2,并且第二维必须为 4,即 shape=(n ,4), 代表n个box
        if bbox.ndimension() != 2:
            raise ValueError(
                "bbox should have 2 dimensions, got {}".format(bbox.ndimension())
            )
        if bbox.size(-1) != 4:
            raise ValueError(
                "last dimension of bbox should have a "
                "size of 4, got {}".format(bbox.size(-1))
            )
        # 只支持以下两种模式
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
         
        # 为成员变量赋值 
        self.bbox = bbox
        self.size = image_size  # (image_width, image_height)
        self.mode = mode
        self.extra_fields = {
    
    } # 以字典结构存储额外信息

    # 添加新的键值或覆盖旧的键值
    def add_field(self, field, field_data):
        self.extra_fields[field] = field_data
    # 获取指定键对应的值
    def get_field(self, field):
        return self.extra_fields[field]
    # 判断额外信息中是否存在该键
    def has_field(self, field):
        return field in self.extra_fields
    # 以列表的形式返回所有键的名称
    def fields(self):
        return list(self.extra_fields.keys())
    # 将另一个boxlist类型的额外信息(字典)复制到额外信息(extra_fields)中
    def _copy_extra_fields(self, bbox):
        for k, v in bbox.extra_fields.items():
            self.extra_fields[k] = v

# 将当前的 bbox 的表示形式转换成参数指定的模式
    def convert(self, mode):
     # 只支持以下两种模式
        if mode not in ("xyxy", "xywh"):
            raise ValueError("mode should be 'xyxy' or 'xywh'")
        if mode == self.mode:
            return self
        # we only have two modes, so don't need to check
        # self.mode
        # 调用成员函数,将坐标转化成(x1,y1,x2,y2)的形式
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if mode == "xyxy":
        # 如模式为“xyxy”,则直接将xmin,ymin,xmax,ymax合并成 n x 4的bbox
            bbox = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
            # 创建一个新的boxlist实例
            bbox = BoxList(bbox, self.size, mode=mode)
        else:
        # 否则就将 xmin,ymin,xmax,ymax转换成(x,y,w,h)后再连接在一起
            TO_REMOVE = 1
            bbox = torch.cat(
                (xmin, ymin, xmax - xmin + TO_REMOVE, ymax - ymin + TO_REMOVE), dim=-1
            )
            bbox = BoxList(bbox, self.size, mode=mode)
        # 复制当前实例的 extra_fields 信息到这个新创建的实例当中,并将这个新的实例返回
        bbox._copy_extra_fields(self)
        return bbox

# 获取bbox的(x1,y1,x2,y2)形式的坐标表示
    def _split_into_xyxy(self):
        if self.mode == "xyxy":
        # x,y 的shape为 n x 1,代表着n个box的x,y坐标
            xmin, ymin, xmax, ymax = self.bbox.split(1, dim=-1)
            return xmin, ymin, xmax, ymax
        elif self.mode == "xywh":
            TO_REMOVE = 1
            xmin, ymin, w, h = self.bbox.split(1, dim=-1)
            return (
                xmin,
                ymin,
                xmin + (w - TO_REMOVE).clamp(min=0),
                ymin + (h - TO_REMOVE).clamp(min=0),
            )
        else:
            raise RuntimeError("Should not be here")

# 将所有的 boxes 按照给定的 size 和图片的尺寸进行缩放,创建一个副本存储缩放后的boxes并返回
    def resize(self, size, *args, **kwargs):
        """
        Returns a resized copy of this bounding box
        :param size: The requested size in pixels, as a 2-tuple:
            (width, height).
            size是指定放缩后的大小
        """
        # 计算宽和高的放缩比例(new_size/old_size)
        ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(size, self.size))
        # 如果宽高放缩比例相同
        if ratios[0] == ratios[1]:
            ratio = ratios[0]
            # 令所有的 bbox 都乘以放缩比例,不论 bbox 是以 xyxy 形式还是以 xywh 表示
            # 乘以系数就可以正确的将 bbox的坐标转换到放缩后图片的对应坐标
            scaled_box = self.bbox * ratio
            bbox = BoxList(scaled_box, size, mode=self.mode)
            # bbox._copy_extra_fields(self)
            for k, v in self.extra_fields.items():
                if not isinstance(v, torch.Tensor):
                    v = v.resize(size, *args, **kwargs)
                bbox.add_field(k, v)
            return bbox
       # 宽和高的放缩比例不同,因此需要拆分后分别放缩然后连接在一起
        ratio_width, ratio_height = ratios
        # 获取 bbox 的左上角和右下角坐标
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        # 分别对宽(xmax,xmin)和高(ymax,ymin)进行放缩
        scaled_xmin = xmin * ratio_width
        scaled_xmax = xmax * ratio_width
        scaled_ymin = ymin * ratio_height
        scaled_ymax = ymax * ratio_height
        # 将左上角和右下角的坐标连接起来,组合缩放后的 bbox 表示
        scaled_box = torch.cat(
            (scaled_xmin, scaled_ymin, scaled_xmax, scaled_ymax), dim=-1
        )
        bbox = BoxList(scaled_box, size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.resize(size, *args, **kwargs)
            bbox.add_field(k, v)
        # 将bbox 转换成指定的格式(因为前面强制转换成了 xyxy模式了)
        return bbox.convert(self.mode)

    def transpose(self, method):
        """
        对bbox 进行转换(翻转或者旋转90度)
        method(int)此处只能为 0 或者 1,目前仅仅支持这两个转换方法
        FLIP_LEFT_RIGHT = 0 , FLIP_TOP_BOTTOM = 1
        Transpose bounding box (flip or rotate in 90 degree steps)
        :param method: One of :py:attr:`PIL.Image.FLIP_LEFT_RIGHT`,
          :py:attr:`PIL.Image.FLIP_TOP_BOTTOM`, :py:attr:`PIL.Image.ROTATE_90`,
          :py:attr:`PIL.Image.ROTATE_180`, :py:attr:`PIL.Image.ROTATE_270`,
          :py:attr:`PIL.Image.TRANSPOSE` or :py:attr:`PIL.Image.TRANSVERSE`.
        """
        if method not in (FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM):
            raise NotImplementedError(
                "Only FLIP_LEFT_RIGHT and FLIP_TOP_BOTTOM implemented"
            )

        # 获取图片的宽和高
        image_width, image_height = self.size
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        if method == FLIP_LEFT_RIGHT:
            TO_REMOVE = 1
            transposed_xmin = image_width - xmax - TO_REMOVE
            transposed_xmax = image_width - xmin - TO_REMOVE
            transposed_ymin = ymin
            transposed_ymax = ymax
        elif method == FLIP_TOP_BOTTOM:
            transposed_xmin = xmin
            transposed_xmax = xmax
            transposed_ymin = image_height - ymax
            transposed_ymax = image_height - ymin

        transposed_boxes = torch.cat(
            (transposed_xmin, transposed_ymin, transposed_xmax, transposed_ymax), dim=-1
        )
        # 根据转换后的 boxes 坐标创建一个新的 BoxList 实例, 同时将 extra_fields 信息复制
        bbox = BoxList(transposed_boxes, self.size, mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.transpose(method)
            bbox.add_field(k, v)
        # 将 bbox 的 mode 转换后返回
        return bbox.convert(self.mode)

    def crop(self, box):
        """
        Crops a rectangular region from this bounding box. The box is a
        4-tuple defining the left, upper, right, and lower pixel
        coordinate.
        box 是一个4元组,指定了希望裁剪的区域的左上角和右下角
        """
        # 获取当前所有 boxes 的最左、最上、最下、最右的坐标
        xmin, ymin, xmax, ymax = self._split_into_xyxy()
        # 获取欲裁剪的 box 的宽和高
        w, h = box[2] - box[0], box[3] - box[1]
        # 根据 box 指定的区域,对现有的 proposals boxes 进行裁剪
        # 即改变其坐标位置,如果发现有超出规定尺寸的情况,则将其截断
        cropped_xmin = (xmin - box[0]).clamp(min=0, max=w)
        cropped_ymin = (ymin - box[1]).clamp(min=0, max=h)
        cropped_xmax = (xmax - box[0]).clamp(min=0, max=w)
        cropped_ymax = (ymax - box[1]).clamp(min=0, max=h)

        # TODO should I filter empty boxes here?
        if False:
            is_empty = (cropped_xmin == cropped_xmax) | (cropped_ymin == cropped_ymax)

        cropped_box = torch.cat(
            (cropped_xmin, cropped_ymin, cropped_xmax, cropped_ymax), dim=-1
        )
        bbox = BoxList(cropped_box, (w, h), mode="xyxy")
        # bbox._copy_extra_fields(self)
        for k, v in self.extra_fields.items():
            if not isinstance(v, torch.Tensor):
                v = v.crop(box)
            bbox.add_field(k, v)
        return bbox.convert(self.mode)

    # Tensor-like methods

    def to(self, device):
    # 设备转移函数
    # device:“cuda:x” or “cpu”
    # 将当前的bbox移动到指定的device上,并且创建一个新的 boxlist 实例
        bbox = BoxList(self.bbox.to(device), self.size, self.mode)
        for k, v in self.extra_fields.items():
            if hasattr(v, "to"):
                v = v.to(device)
            bbox.add_field(k, v)
        return bbox

    def __getitem__(self, item):
        bbox = BoxList(self.bbox[item], self.size, self.mode)
        for k, v in self.extra_fields.items():
            bbox.add_field(k, v[item])
        return bbox

    def __len__(self):
        return self.bbox.shape[0]

    def clip_to_image(self, remove_empty=True):
    # 该函数将 bbox 的坐标限制在 image 的尺寸内
        TO_REMOVE = 1
        self.bbox[:, 0].clamp_(min=0, max=self.size[0] - TO_REMOVE)
        self.bbox[:, 1].clamp_(min=0, max=self.size[1] - TO_REMOVE)
        self.bbox[:, 2].clamp_(min=0, max=self.size[0] - TO_REMOVE)
        self.bbox[:, 3].clamp_(min=0, max=self.size[1] - TO_REMOVE)
        if remove_empty:
            box = self.bbox
            keep = (box[:, 3] > box[:, 1]) & (box[:, 2] > box[:, 0])
            return self[keep]
        return self

    def area(self):
    # 获取区域的面积
        box = self.bbox
        if self.mode == "xyxy":
            TO_REMOVE = 1
            area = (box[:, 2] - box[:, 0] + TO_REMOVE) * (box[:, 3] - box[:, 1] + TO_REMOVE)
        elif self.mode == "xywh":
            area = box[:, 2] * box[:, 3]
        else:
            raise RuntimeError("Should not be here")

        return area

    def copy_with_fields(self, fields, skip_missing=False):
    # 深度复制函数
        bbox = BoxList(self.bbox, self.size, self.mode)
        if not isinstance(fields, (list, tuple)):
            fields = [fields]
        for field in fields:
            if self.has_field(field):
                bbox.add_field(field, self.get_field(field))
            elif not skip_missing:
                raise KeyError("Field '{}' not found in {}".format(field, self))
        return bbox

    def __repr__(self):
        s = self.__class__.__name__ + "("
        s += "num_boxes={}, ".format(len(self))
        s += "image_width={}, ".format(self.size[0])
        s += "image_height={}, ".format(self.size[1])
        s += "mode={})".format(self.mode)
        return s


if __name__ == "__main__":
    bbox = BoxList([[0, 0, 10, 10], [0, 0, 5, 5]], (10, 10))
    s_bbox = bbox.resize((5, 5))
    print(s_bbox)
    print(s_bbox.bbox)

    t_bbox = bbox.transpose(0)
    print(t_bbox)
    print(t_bbox.bbox)
  • boxlist_ops.py:
import torch

from .bounding_box import BoxList

from maskrcnn_benchmark.layers import nms as _box_nms

# 会对一个boxlist类型数据中的box执行非极大抑制算法
def boxlist_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores"):
    """
    Performs non-maximum suppression on a boxlist, with scores specified
    in a boxlist field via score_field.
    Arguments:
        boxlist(BoxList)
        nms_thresh (float)
        max_proposals (int): if > 0, then only the top max_proposals are kept
            after non-maximum suppression
        score_field (str)
    """
    if nms_thresh <= 0:
        return boxlist
    mode = boxlist.mode # 缓存当前的模式
    boxlist = boxlist.convert("xyxy") # 转换成指定模式
    boxes = boxlist.bbox # 获取 n*4 的bbox 列表
    score = boxlist.get_field(score_field) # 获取对应的 score 列表
    # 调用_box_nms 执行非极大值抑制
    keep = _box_nms(boxes, score, nms_thresh)
    if max_proposals > 0:
        keep = keep[: max_proposals]
    boxlist = boxlist[keep]
    return boxlist.convert(mode)


def remove_small_boxes(boxlist, min_size):
    """
    使得 boxlist 只保留那些尺寸大于一定值的box
    Only keep boxes with both sides >= min_size
    Arguments:
        boxlist (Boxlist)
        min_size (int)
    """
    # TODO maybe add an API for querying the ws / hs
    xywh_boxes = boxlist.convert("xywh").bbox
    _, _, ws, hs = xywh_boxes.unbind(dim=1)
    keep = (
        (ws >= min_size) & (hs >= min_size)
    ).nonzero().squeeze(1)
    return boxlist[keep]


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def boxlist_iou(boxlist1, boxlist2):
    """Compute the intersection over union of two set of boxes.
    The box order must be (xmin, ymin, xmax, ymax).
    Arguments:
      box1: (BoxList) bounding boxes, sized [N,4].
      box2: (BoxList) bounding boxes, sized [M,4].
    Returns:
      (tensor) iou, sized [N,M].
    Reference:
      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
      计算并交比的
    """
    if boxlist1.size != boxlist2.size:
        raise RuntimeError(
                "boxlists should have same image size, got {}, {}".format(boxlist1, boxlist2))
    boxlist1 = boxlist1.convert("xyxy")
    boxlist2 = boxlist2.convert("xyxy")
    N = len(boxlist1)
    M = len(boxlist2)

    area1 = boxlist1.area()
    area2 = boxlist2.area()

    box1, box2 = boxlist1.bbox, boxlist2.bbox

    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]

    TO_REMOVE = 1

    wh = (rb - lt + TO_REMOVE).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    iou = inter / (area1[:, None] + area2 - inter)
    return iou

# 该函数会将一个组成元素为 boxlist 的列表合并成一个 boxlist 对象。
def cat_boxlist(bboxes):
    """
    Concatenates a list of BoxList (having the same image size) into a
    single BoxList
    Arguments:
        bboxes (list[BoxList])
    """
    # 确保类型为列表或者元组,且其中的元素类型为 boxlist
    assert isinstance(bboxes, (list, tuple))
    assert all(isinstance(bbox, BoxList) for bbox in bboxes)

    # 确保所有的 boxlist 的size、mode以及extra_fields 字典的 keys 是相同的
    size = bboxes[0].size
    assert all(bbox.size == size for bbox in bboxes)

    mode = bboxes[0].mode
    assert all(bbox.mode == mode for bbox in bboxes)

    fields = set(bboxes[0].fields())
    assert all(set(bbox.fields()) == fields for bbox in bboxes)

    # 调用本文件的 _cat() 方法,将 bboxes 里面的 boxlist 数据连接成一个 boxlist
    cat_boxes = BoxList(_cat([bbox.bbox for bbox in bboxes], dim=0), size, mode)

    # 将各个 boxlist 的 fields 补充上
    for field in fields:
        data = _cat([bbox.get_field(field) for bbox in bboxes], dim=0)
        cat_boxes.add_field(field, data)

    return cat_boxes

猜你喜欢

转载自blog.csdn.net/qq_43348528/article/details/107535805
今日推荐