Easy-FPN ソースコードの解釈(3):bbox

Easy-FPN ソースコードの解釈(3):bbox

BBox コードの分析

import torch
from torch import Tensor


class BBox(object):
    """A single axis-aligned box plus static helpers for batches of boxes.

    Unless a helper says otherwise, batched tensors have shape (N, 4) with
    columns (left, top, right, bottom), i.e. corner form (x0, y0, x1, y1).
    "Center form" means (center_x, center_y, width, height).
    """

    def __init__(self, left: float, top: float, right: float, bottom: float):
        """Store the four edge coordinates of one box."""
        super().__init__()
        self.left = left
        self.top = top
        self.right = right
        self.bottom = bottom

    def __repr__(self) -> str:
        """Show the four edges rounded to one decimal place."""
        return 'BBox[l={:.1f}, t={:.1f}, r={:.1f}, b={:.1f}]'.format(
            self.left, self.top, self.right, self.bottom)

    def tolist(self) -> list:
        """Return the box as ``[left, top, right, bottom]``."""
        return [self.left, self.top, self.right, self.bottom]

    @staticmethod
    def to_center_base(bboxes: Tensor) -> Tensor:
        """Convert corner-form boxes to center form.

        Args:
            bboxes: (N, 4) tensor of (x0, y0, x1, y1).

        Returns:
            (N, 4) tensor of (center_x, center_y, width, height), where
            width = x1 - x0 and height = y1 - y0 (no +1 pixel convention).
        """
        return torch.stack([
            (bboxes[:, 0] + bboxes[:, 2]) / 2,
            (bboxes[:, 1] + bboxes[:, 3]) / 2,
            bboxes[:, 2] - bboxes[:, 0],
            bboxes[:, 3] - bboxes[:, 1]
        ], dim=1)

    @staticmethod
    def from_center_base(center_based_bboxes: Tensor) -> Tensor:
        """Convert center-form boxes back to corner form.

        Args:
            center_based_bboxes: (N, 4) tensor of (center_x, center_y, width, height).

        Returns:
            (N, 4) tensor of (x0, y0, x1, y1). This is the inverse of
            ``to_center_base``.
        """
        return torch.stack([
            center_based_bboxes[:, 0] - center_based_bboxes[:, 2] / 2,
            center_based_bboxes[:, 1] - center_based_bboxes[:, 3] / 2,
            center_based_bboxes[:, 0] + center_based_bboxes[:, 2] / 2,
            center_based_bboxes[:, 1] + center_based_bboxes[:, 3] / 2
        ], dim=1)

    @staticmethod
    def calc_transformer(src_bboxes: Tensor, dst_bboxes: Tensor) -> Tensor:
        """Encode the row-wise offset that maps each source box onto its target box.

        Args:
            src_bboxes: (N, 4) corner-form boxes, e.g. anchors.
            dst_bboxes: (N, 4) corner-form boxes paired row-by-row with ``src_bboxes``.

        Returns:
            (N, 4) tensor of (dx, dy, dw, dh), where, in center form,
            dx = (x_dst - x_src) / w_src, dy = (y_dst - y_src) / h_src,
            dw = log(w_dst / w_src) and dh = log(h_dst / h_src).
            Zero source widths/heights are not guarded and give inf/nan.
        """
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = BBox.to_center_base(dst_bboxes)
        transformers = torch.stack([
            (center_based_dst_bboxes[:, 0] - center_based_src_bboxes[:, 0]) / center_based_src_bboxes[:, 2],
            (center_based_dst_bboxes[:, 1] - center_based_src_bboxes[:, 1]) / center_based_src_bboxes[:, 3],
            torch.log(center_based_dst_bboxes[:, 2] / center_based_src_bboxes[:, 2]),
            torch.log(center_based_dst_bboxes[:, 3] / center_based_src_bboxes[:, 3])
        ], dim=1)
        return transformers

    @staticmethod
    def apply_transformer(src_bboxes: Tensor, transformers: Tensor) -> Tensor:
        """Decode offsets produced by ``calc_transformer`` back into boxes.

        Args:
            src_bboxes: (N, 4) corner-form boxes, e.g. anchors.
            transformers: (N, 4) offsets (dx, dy, dw, dh) paired row-by-row
                with ``src_bboxes``.

        Returns:
            (N, 4) corner-form boxes, e.g. proposals. This is the inverse of
            ``calc_transformer``.
        """
        center_based_src_bboxes = BBox.to_center_base(src_bboxes)
        center_based_dst_bboxes = torch.stack([
            transformers[:, 0] * center_based_src_bboxes[:, 2] + center_based_src_bboxes[:, 0],
            transformers[:, 1] * center_based_src_bboxes[:, 3] + center_based_src_bboxes[:, 1],
            torch.exp(transformers[:, 2]) * center_based_src_bboxes[:, 2],
            torch.exp(transformers[:, 3]) * center_based_src_bboxes[:, 3]
        ], dim=1)
        dst_bboxes = BBox.from_center_base(center_based_dst_bboxes)
        return dst_bboxes

    @staticmethod
    def iou(source: Tensor, other: Tensor) -> Tensor:
        """Pairwise intersection-over-union between two sets of boxes.

        Args:
            source: (N, 4) corner-form boxes.
            other: (M, 4) corner-form boxes.

        Returns:
            (N, M) tensor whose entry [i, j] is IoU(source[i], other[j]).
            A pair whose union area is zero yields 0 / 0 (nan); this is not guarded.
        """
        # Broadcast (N, 1, 4) against (1, M, 4) rather than materialising two
        # repeated (N, M, 4) copies; every pairwise quantity below is (N, M).
        source = source.unsqueeze(1)
        other = other.unsqueeze(0)

        source_area = (source[..., 2] - source[..., 0]) * (source[..., 3] - source[..., 1])  # (N, 1)
        other_area = (other[..., 2] - other[..., 0]) * (other[..., 3] - other[..., 1])  # (1, M)

        intersection_left = torch.max(source[..., 0], other[..., 0])
        intersection_top = torch.max(source[..., 1], other[..., 1])
        intersection_right = torch.min(source[..., 2], other[..., 2])
        intersection_bottom = torch.min(source[..., 3], other[..., 3])
        # Non-overlapping pairs have right < left (or bottom < top); clamping
        # the extent at 0 makes their intersection area 0 instead of negative.
        intersection_width = torch.clamp(intersection_right - intersection_left, min=0)
        intersection_height = torch.clamp(intersection_bottom - intersection_top, min=0)
        intersection_area = intersection_width * intersection_height

        return intersection_area / (source_area + other_area - intersection_area)

    @staticmethod
    def inside(source: Tensor, other: Tensor) -> Tensor:
        """Pairwise test of whether each ``source`` box lies within each ``other`` box.

        Typical use is keeping only anchors that lie inside the image bounds.

        Args:
            source: (N, 4) corner-form boxes.
            other: (M, 4) corner-form boxes.

        Returns:
            (N, M) boolean tensor. Entry [i, j] is True when all four edges of
            source[i] are within other[j]; touching edges count as inside.

        Example:
            >>> anchors = torch.tensor([[1., 1., 2., 2.], [-1., 0., 2., 2.]])
            >>> image = torch.tensor([[0., 0., 10., 10.]])
            >>> BBox.inside(anchors, image).tolist()
            [[True], [False]]
        """
        source = source.unsqueeze(1)  # (N, 1, 4), broadcast against other
        other = other.unsqueeze(0)  # (1, M, 4)
        return ((source[..., 0] >= other[..., 0]) & (source[..., 1] >= other[..., 1]) &
                (source[..., 2] <= other[..., 2]) & (source[..., 3] <= other[..., 3]))

    @staticmethod
    def clip(bboxes: Tensor, left: float, top: float, right: float, bottom: float) -> Tensor:
        """Clamp corner-form boxes into the rectangle [left, right] x [top, bottom].

        Args:
            bboxes: (N, 4) tensor of (x0, y0, x1, y1).
            left, top, right, bottom: bounds of the clipping rectangle.

        Returns:
            (N, 4) tensor where the x columns (0, 2) are clamped to [left, right]
            and the y columns (1, 3) are clamped to [top, bottom].
        """
        return torch.stack([
            torch.clamp(bboxes[:, 0], min=left, max=right),
            torch.clamp(bboxes[:, 1], min=top, max=bottom),
            torch.clamp(bboxes[:, 2], min=left, max=right),
            torch.clamp(bboxes[:, 3], min=top, max=bottom)
        ], dim=1)

例1:
a = torch.Tensor([[3, 2, 6, 9]])
b = torch.Tensor([[3, 5, 6, 7], [3, 5, 6, 7]])
# Concatenate the single box and the two-box batch into one (3, 4) tensor.
c = torch.cat((a, b), dim=0)
d = torch.Tensor([[0, 0, 3, 6]])
print(c.shape)
# Tile both operands to (N, M, 4) so every box in c can be compared
# element-wise against every box in d.
num_other = d.shape[0]
c = c.repeat(num_other, 1, 1).permute(1, 0, 2)
num_source = c.shape[0]
d = d.repeat(num_source, 1, 1)
print(c.shape)
print(d.shape)

>> torch.Size([3, 4])
>> torch.Size([3, 1, 4])
>> torch.Size([3, 1, 4])

おすすめ

転載: blog.csdn.net/ThunderF/article/details/104733793