Implementing IoU loss in PyTorch

I was recently working on an object detection project. I assumed PyTorch would ship with a built-in IoU loss, but I couldn't find one after searching around, so I implemented it myself.

Here is the code for anyone who needs it:

import torch


def iou_loss(predicted_boxes, target_boxes, eps=1e-7):
    """
    Compute the IoU loss.
    Args:
        predicted_boxes: predicted bbox coordinates, shape (N, 4), where N is the batch size.
        target_boxes: ground-truth bbox coordinates, shape (N, 4), where N is the batch size.
        eps: small constant added to the union area to avoid division by zero.
    Returns:
        iou_loss: per-box IoU loss, shape (N,). Call .mean() on it for a scalar training loss.
    """
    # Top-left (x1, y1) and bottom-right (x2, y2) corners of the predicted and ground-truth boxes
    pred_x1, pred_y1, pred_x2, pred_y2 = predicted_boxes.unbind(dim=1)
    true_x1, true_y1, true_x2, true_y2 = target_boxes.unbind(dim=1)

    # Top-left and bottom-right corners of the intersection rectangle
    xi1 = torch.max(pred_x1, true_x1)
    yi1 = torch.max(pred_y1, true_y1)
    xi2 = torch.min(pred_x2, true_x2)
    yi2 = torch.min(pred_y2, true_y2)

    # Intersection area (clamped to 0 when the boxes do not overlap) and union area
    inter_area = torch.clamp(xi2 - xi1, min=0) * torch.clamp(yi2 - yi1, min=0)
    pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
    true_area = (true_x2 - true_x1) * (true_y2 - true_y1)
    union_area = pred_area + true_area - inter_area

    # IoU; eps keeps the division finite when both boxes have zero area
    iou = inter_area / (union_area + eps)

    # IoU loss
    iou_loss = 1.0 - iou

    return iou_loss
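As a quick sanity check of the arithmetic the function performs, here is the formula it implements for a predicted box P and a target box T, followed by a worked example on two made-up boxes in (x1, y1, x2, y2) format (the coordinates are illustrative, not from the original post):

```latex
\mathrm{IoU}(P, T) = \frac{|P \cap T|}{|P \cup T|}
                   = \frac{|P \cap T|}{|P| + |T| - |P \cap T|},
\qquad
\mathcal{L}_{\mathrm{IoU}} = 1 - \mathrm{IoU}(P, T)

% Example: P = (0, 0, 2, 2), T = (1, 1, 3, 3)
% Intersection corners: (\max(0,1), \max(0,1), \min(2,3), \min(2,3)) = (1, 1, 2, 2)
|P \cap T| = (2 - 1)(2 - 1) = 1, \quad |P| = |T| = 4, \quad |P \cup T| = 4 + 4 - 1 = 7
\mathrm{IoU} = \tfrac{1}{7} \approx 0.143, \qquad \mathcal{L}_{\mathrm{IoU}} = \tfrac{6}{7} \approx 0.857
```

Identical boxes give a loss of 0, and boxes that do not overlap at all give a loss of 1.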

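Note that this code expects each box as the coordinates of its top-left and bottom-right corners, i.e. (x1, y1, x2, y2) with x2 >= x1 and y2 >= y1.

So check the format of the data you pass in: boxes in another format, such as (center x, center y, width, height), must be converted first, otherwise the computed areas and overlaps will be wrong.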
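If your boxes come in center format, a minimal conversion sketch is below. The helper name `cxcywh_to_xyxy` is hypothetical (it is not from the original post), and it assumes a float tensor of shape (N, 4) with columns (center x, center y, width, height). If torchvision is installed, `torchvision.ops.box_convert(boxes, in_fmt="cxcywh", out_fmt="xyxy")` should give the same result.

```python
import torch


def cxcywh_to_xyxy(boxes):
    """Convert boxes from (center_x, center_y, width, height) to (x1, y1, x2, y2).

    Hypothetical helper for illustration; expects a float tensor of shape (N, 4).
    """
    cx, cy, w, h = boxes.unbind(dim=1)
    # Half-extents: distance from the center to each edge of the box
    half_w, half_h = w / 2, h / 2
    # Top-left corner = center minus half-extents, bottom-right = center plus half-extents
    return torch.stack((cx - half_w, cy - half_h, cx + half_w, cy + half_h), dim=1)


if __name__ == "__main__":
    boxes = torch.tensor([[1.0, 1.0, 2.0, 2.0], [2.0, 3.0, 4.0, 2.0]])
    print(cxcywh_to_xyxy(boxes))  # rows (0, 0, 2, 2) and (0, 2, 4, 4)
```

The converted tensors can then be passed to iou_loss as predicted_boxes and target_boxes. Because the conversion uses only differentiable tensor operations, gradients still flow back to the center-format predictions.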


Source: blog.csdn.net/weixin_53374931/article/details/130441242