Faster R-CNN loss — PyTorch implementation

Papers: "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks"; "Fast R-CNN"

Faster R-CNN loss:

RPN loss: computed from the anchors and the scores and regression parameters (regs) that the RPN outputs for them, so that the network learns to refine the anchors into good proposals.
The RPN loss consists of a classification loss $L_{cls}$ and a regression loss $L_{reg}$: $L_{cls}$ is computed with the binary cross-entropy (BCE) loss and $L_{reg}$ with the smooth L1 loss.
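RPN loss:
$$L(\{p_i\}, \{t_i\}) = \frac{1}{N_{cls}} \sum_i L_{cls}(p_i, p_i^*) + \lambda \frac{1}{N_{reg}} \sum_i p_i^* L_{reg}(t_i, t_i^*)$$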
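In the formula, $p_i$ is the predicted objectness of anchor $i$, $p_i^*$ is 1 for a positive anchor and 0 for a negative one, and $t_i$, $t_i^*$ are the predicted and ground-truth regression parameters. $N_{cls}$ is the number of sampled anchors (the mini-batch size, 256), $N_{reg}$ is the number of anchor locations (about 2400), and $\lambda = 10$. With these values $\lambda / N_{reg} \approx 1 / N_{cls}$, so both terms are weighted roughly equally; the common normalization coefficient can therefore be dropped, and the loss becomes a plain sum of the two terms.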
RCNN loss: computed from the proposals and the scores and regression parameters (regs) that the detection head outputs for them.
The RCNN loss consists of a classification loss $L_{cls}$ and a localization loss $L_{loc}$: $L_{cls}$ is computed with the cross-entropy (CE) loss and $L_{loc}$ with the smooth L1 loss.
RCNN loss:
$$L(p, u, t^u, v) = L_{cls}(p, u) + \lambda [u \ge 1] L_{loc}(t^u, v)$$
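Here $p$ is the predicted class distribution, $u$ the ground-truth class, $t^u$ the predicted regression parameters for class $u$, and $v$ the ground-truth regression target. $[u \ge 1]$ is the Iverson bracket: it equals 1 when $u \ge 1$ (a foreground class) and 0 otherwise (background, $u = 0$). As a result, background proposals contribute no localization loss.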
The regression parameters are computed as:
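$$t_x = \frac{x - x_a}{w_a},\quad t_y = \frac{y - y_a}{h_a},\quad t_w = \log\frac{w}{w_a},\quad t_h = \log\frac{h}{h_a}$$
$$t_x^* = \frac{x^* - x_a}{w_a},\quad t_y^* = \frac{y^* - y_a}{h_a},\quad t_w^* = \log\frac{w^*}{w_a},\quad t_h^* = \log\frac{h^*}{h_a}$$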
Here $x$, $y$ are the box center coordinates and $w$, $h$ its width and height. $x$, $x_a$, and $x^*$ refer to the predicted box, the anchor box, and the ground-truth box respectively (likewise for $y$, $w$, $h$).

Training Faster R-CNN:

The original paper trains the RPN and Fast R-CNN alternately, but in practice the two networks can also be trained jointly.

import torch
import torch.nn.functional as F

from torchvision.models import vgg16

from model import RegionProposalNetwork, Head, FasterRCNN


def cal_iou(box1, box2):
    """
    Compute the pairwise IoU (intersection over union) of two sets of boxes.

    :param box1: tensor (N1, 4) in (xmin, ymin, xmax, ymax) format
    :param box2: tensor (N2, 4) in (xmin, ymin, xmax, ymax) format
    :return: tensor (N1, N2) whose entry [i, j] is IoU(box1[i], box2[j]), in [0, 1]
    """
    # Per-box areas; box1's are made a column so they broadcast against box2's row.
    area1 = (box1[:, 2:] - box1[:, :2]).prod(dim=-1)[:, None]  # (N1, 1)
    area2 = (box2[:, 2:] - box2[:, :2]).prod(dim=-1)  # (N2,)

    # Corners of every pairwise intersection rectangle, shape (N1, N2, 2).
    top_left = torch.maximum(box1[:, None, :2], box2[:, :2])
    bottom_right = torch.minimum(box1[:, None, 2:], box2[:, 2:])

    # Disjoint pairs get a negative extent, which is clamped to zero.
    overlap = (bottom_right - top_left).clamp(min=0).prod(dim=-1)  # (N1, N2)
    union = area1 + area2 - overlap
    return overlap / union


def gtbbox2reg(gt_bbox, bbox):
    """
    Encode ground-truth boxes as regression targets relative to reference boxes (anchors or RoIs).

    For row i, with centre/size (x, y, w, h) of bbox[i] and (x*, y*, w*, h*) of gt_bbox[i]:
        dx = (x* - x) / w,  dy = (y* - y) / h,  dw = log(w* / w),  dh = log(h* / h)

    :param gt_bbox: ground-truth boxes (N1, 4) (xmin, ymin, xmax, ymax)
    :param bbox: reference boxes (anchors / RoIs) (N1, 4) (xmin, ymin, xmax, ymax), row-aligned with gt_bbox
    :return: ground-truth regression parameters (N1, 4) (dx, dy, dw, dh)
    """
    def _centre_size(boxes):
        # Slicing with k::4 keeps each coordinate as a 2-D (N, 1) column.
        x0, y0, x1, y1 = (boxes[:, k::4] for k in range(4))
        return (x0 + x1) / 2, (y0 + y1) / 2, x1 - x0, y1 - y0

    gt_cx, gt_cy, gt_w, gt_h = _centre_size(gt_bbox)
    ref_cx, ref_cy, ref_w, ref_h = _centre_size(bbox)
    return torch.cat(
        [
            (gt_cx - ref_cx) / ref_w,
            (gt_cy - ref_cy) / ref_h,
            torch.log(gt_w / ref_w),
            torch.log(gt_h / ref_h),
        ],
        dim=-1,
    )


class RPNLoss(object):  # RPN loss function, computed over anchors
    """
    RPN loss for one image: binary objectness loss over a sampled mini-batch of
    anchors plus smooth-L1 box-regression loss over the sampled positive anchors.
    """

    def __init__(
            self,
            num_sample=256,  # number of anchors sampled per image
            pos_iou_th=0.7,  # positive-sample IoU threshold (>=)
            neg_iou_th=0.3,  # negative-sample IoU threshold (<)
            pos_ratio=0.5,  # maximum fraction of positives in the sample
            cuda=True
    ):
        self.num_sample = num_sample
        self.pos_iou_th = pos_iou_th
        self.neg_iou_th = neg_iou_th
        self.pos_ratio = pos_ratio
        # Kept for backward compatibility; targets are created on the device of `scores`.
        self.cuda = cuda

    @staticmethod
    def _subsample(index, max_num):
        """Randomly keep at most `max_num` entries of `index`, sampling without replacement."""
        if len(index) > max_num:
            keep = torch.randperm(len(index), device=index.device)[:max_num]
            index = index[keep]
        return index

    def cal_loss(self, scores, regs, anchors, gt_bbox):
        """
        Compute the RPN loss for one image in a batch.

        :param scores: tensor (N1,) objectness per anchor. It is fed to F.binary_cross_entropy,
            which requires values in [0, 1], so presumably sigmoid outputs — TODO confirm
        :param regs: tensor (N1, 4) predicted regression parameters per anchor
        :param anchors: tensor (N1, 4) (xmin, ymin, xmax, ymax)
        :param gt_bbox: tensor (N2, 4) (xmin, ymin, xmax, ymax); assumed non-empty
        :return: scalar tensor Lcls + Lreg. Lcls is summed over sampled anchors; Lreg is summed
            over positive anchors of the per-anchor mean over the 4 coordinates.
        """
        # Anchor labels: 1 = positive, 0 = negative, -1 = ignored.
        anchor_label = torch.full_like(scores, -1)
        iou = cal_iou(anchors, gt_bbox)  # (N1, N2)
        _, gt_anchor_index = torch.max(iou, dim=0)  # best anchor for each gt, (N2,)
        anchor_max_iou, anchor_gt_index = torch.max(iou, dim=1)  # best gt for each anchor, (N1,)
        anchor_label[anchor_max_iou >= self.pos_iou_th] = 1
        anchor_label[anchor_max_iou < self.neg_iou_th] = 0
        # Every gt gets at least one positive anchor, even if its best IoU is below pos_iou_th.
        anchor_label[gt_anchor_index] = 1
        pos_index = torch.where(anchor_label == 1)[0]
        neg_index = torch.where(anchor_label == 0)[0]

        pos_index = self._subsample(pos_index, int(self.num_sample * self.pos_ratio))
        num_pos = len(pos_index)
        # Fill the rest of the mini-batch with negatives; fewer may be available than requested,
        # so count what was actually kept.
        neg_index = self._subsample(neg_index, self.num_sample - num_pos)
        num_neg = len(neg_index)

        pred_scores = torch.cat([scores[pos_index], scores[neg_index]], dim=0)
        gt_scores = torch.cat(
            [
                torch.ones(num_pos, dtype=scores.dtype, device=scores.device),
                torch.zeros(num_neg, dtype=scores.dtype, device=scores.device),
            ],
            dim=0,
        )
        # Sum reduction == (num_pos + num_neg) * mean, and is 0 (not NaN) on an empty sample.
        score_loss = F.binary_cross_entropy(pred_scores, gt_scores, reduction="sum")

        pred_regs = regs[pos_index]
        pos_anchor = anchors[pos_index]
        pos_gt = gt_bbox[anchor_gt_index[pos_index]]
        gt_regs = gtbbox2reg(pos_gt, pos_anchor)

        # Per-anchor mean over the 4 coordinates, summed over positives
        # (same as num_pos * mean, but 0 instead of NaN when there are no positives).
        reg_loss = F.smooth_l1_loss(pred_regs, gt_regs, reduction="none").mean(dim=-1).sum()

        return score_loss + reg_loss


class RCNNLoss(object):  # RCNN loss function, computed over proposals (RoIs)
    """
    Detection-head loss for one image: cross-entropy classification over a sampled
    mini-batch of RoIs plus class-specific smooth-L1 regression over the positive RoIs.
    Negative RoIs are trained towards `background_label`.
    """

    def __init__(
            self,
            num_sample=128,  # number of RoIs sampled per image
            pos_iou_th=0.5,  # positive-sample IoU threshold (>=)
            neg_iou_th=(0.1, 0.5),  # negative-sample IoU range [low, high)
            pos_ratio=0.25,  # maximum fraction of positives in the sample
            cuda=True,
            background_label=0,  # class index used as the target for negative RoIs
    ):
        self.num_sample = num_sample
        self.pos_iou_th = pos_iou_th
        self.neg_iou_th = neg_iou_th
        self.pos_ratio = pos_ratio
        # Kept for backward compatibility; targets are created on the device of `scores`.
        self.cuda = cuda
        self.background_label = background_label

    @staticmethod
    def _subsample(index, max_num):
        """Randomly keep at most `max_num` entries of `index`, sampling without replacement."""
        if len(index) > max_num:
            keep = torch.randperm(len(index), device=index.device)[:max_num]
            index = index[keep]
        return index

    def cal_loss(self, scores, regs, rois, gt_label, gt_bbox):
        """
        Compute the RCNN (detection head) loss for one image in a batch.

        :param scores: tensor (N1, C) raw class logits, C including the background class
            (F.cross_entropy applies the log-softmax itself)
        :param regs: tensor (N1, 4 * C) class-specific regression parameters, 4 per class
        :param rois: tensor (N1, 4) (xmin, ymin, xmax, ymax)
        :param gt_label: tensor (N2,) integer class index of each gt box (foreground classes,
            distinct from `background_label`)
        :param gt_bbox: tensor (N2, 4) (xmin, ymin, xmax, ymax); assumed non-empty
        :return: scalar tensor Lcls + Lloc. Lcls is summed over sampled RoIs; Lloc is summed
            over positive RoIs of the per-RoI mean over the 4 coordinates.
        """
        iou = cal_iou(rois, gt_bbox)  # (N1, N2)
        roi_max_iou, roi_gt_index = torch.max(iou, dim=1)  # best gt for each RoI, (N1,)
        # RoI labels: 1 = positive, 0 = negative, -1 = ignored. Created on the same device as
        # the masks used to index it.
        roi_label = torch.full_like(roi_max_iou, -1)
        roi_label[roi_max_iou >= self.pos_iou_th] = 1
        neg_low, neg_high = self.neg_iou_th
        roi_label[(roi_max_iou >= neg_low) & (roi_max_iou < neg_high)] = 0
        pos_index = torch.where(roi_label == 1)[0]
        neg_index = torch.where(roi_label == 0)[0]

        pos_index = self._subsample(pos_index, int(self.num_sample * self.pos_ratio))
        num_pos = len(pos_index)
        # Fill the rest of the mini-batch with negatives; fewer may be available than requested.
        neg_index = self._subsample(neg_index, self.num_sample - num_pos)
        num_neg = len(neg_index)

        pos_gt_index = roi_gt_index[pos_index]
        pos_gt_label = gt_label[pos_gt_index].to(device=scores.device, dtype=torch.long)
        # Negatives are background, regardless of which gt box they overlap most.
        neg_gt_label = torch.full((num_neg,), self.background_label, dtype=torch.long, device=scores.device)

        pred_scores = torch.cat([scores[pos_index], scores[neg_index]], dim=0)
        gt_scores = torch.cat([pos_gt_label, neg_gt_label], dim=0)
        # Sum reduction == (num_pos + num_neg) * mean, and is 0 (not NaN) on an empty sample.
        score_loss = F.cross_entropy(pred_scores, gt_scores, reduction="sum")

        # Reshape all RoIs before indexing so that an empty positive set still yields (0, 4).
        pred_regs = regs.reshape(regs.shape[0], -1, 4)[pos_index]  # (num_pos, C, 4)
        # Select, for each positive RoI, the 4 regression outputs of its gt class.
        pred_label_regs = pred_regs[torch.arange(num_pos, device=pred_regs.device), pos_gt_label]

        pos_gt = gt_bbox[pos_gt_index]
        gt_regs = gtbbox2reg(pos_gt, rois[pos_index])

        # Per-RoI mean over the 4 coordinates, summed over positives
        # (same as num_pos * mean, but 0 instead of NaN when there are no positives).
        loc_loss = F.smooth_l1_loss(pred_label_regs, gt_regs, reduction="none").mean(dim=-1).sum()

        return score_loss + loc_loss


if __name__ == "__main__":
    # Smoke test: push a random batch through the model and print per-image losses.
    cuda = torch.cuda.is_available()  # fall back to CPU instead of crashing on CPU-only machines
    backbone = vgg16().features  # the VGG16 feature extractor serves as the FasterRCNN backbone
    batch_size = 8
    feature_channels = 512  # number of channels in the VGG16 feature map
    step = 32  # stride of the VGG16 feature map relative to the input image
    num_classes = 20  # number of object classes (excluding background)
    image_size = (800, 1300)  # input image size; matches the dummy input below
    rpn = RegionProposalNetwork(feature_channels, step, image_size, cuda=cuda)
    head = Head(num_classes + 1, feature_channels, step)  # +1 output for the background class
    fasterrcnn = FasterRCNN(backbone, rpn, head)

    data = torch.randn(batch_size, 3, *image_size)  # dummy network input
    if cuda:
        data = data.cuda()
        fasterrcnn.cuda()
    rpn_scores, rpn_regs, anchors, rois, head_scores, head_regs = fasterrcnn(data)
    # Expected shapes:
    # rpn_scores  torch.Size([8, 9000])
    # rpn_regs    torch.Size([8, 9000, 4])
    # anchors     torch.Size([9000, 4])
    # rois        torch.Size([8, 2000, 4])
    # head_scores torch.Size([8, 2000, 21])
    # head_regs   torch.Size([8, 2000, 84])
    for output in (rpn_scores, rpn_regs, anchors, rois, head_scores, head_regs):
        print(output.shape)

    rpnloss_fun = RPNLoss(cuda=cuda)
    rcnnloss_fun = RCNNLoss(cuda=cuda)
    for i in range(batch_size):  # the losses are computed one image at a time
        # Simulated ground truth: 6 boxes in (xmin, ymin, xmax, ymax) format.
        gtbbox_min = torch.randint(0, 200, (6, 2))
        gtbbox_max = torch.randint(600, 800, (6, 2))
        gtbbox = torch.cat([gtbbox_min, gtbbox_max], dim=-1).float()
        # Foreground labels in 1..num_classes: always valid indices into the num_classes + 1
        # head outputs, with index 0 left for background.
        gtlabel = torch.randint(1, num_classes + 1, (6,))
        if cuda:
            gtbbox = gtbbox.cuda()
            gtlabel = gtlabel.cuda()
        rpnloss = rpnloss_fun.cal_loss(rpn_scores[i], rpn_regs[i], anchors, gtbbox)  # RPN loss of one image
        rcnnloss = rcnnloss_fun.cal_loss(head_scores[i], head_regs[i], rois[i], gtlabel, gtbbox)  # RCNN loss of one image
        print(rpnloss)
        print(rcnnloss)

Source: blog.csdn.net/Peach_____/article/details/128742988