pytorch 实现IOU损失

最近在做目标检测的项目,本来以为Pytorch会自带iou损失,结果找了一圈没找到,于是自己实现了一下:

需要的同学自取

def iou_loss(predicted_boxes, target_boxes):
    """
    计算IoU损失
    Args:
        predicted_boxes: 预测的bbox坐标,形状为(N, 4),其中N是batch size。
        target_boxes: 真实的bbox坐标,形状为(N, 4),其中N是batch size。
    Returns:
        iou_loss: IoU损失,形状为(N,)。
    """
    # 计算预测框和真实框的左上角和右下角坐标
    pred_x1, pred_y1, pred_x2, pred_y2 = predicted_boxes[:, 0], predicted_boxes[:, 1], predicted_boxes[:,
                                                                                       2], predicted_boxes[:, 3]
    true_x1, true_y1, true_x2, true_y2 = target_boxes[:, 0], target_boxes[:, 1], target_boxes[:, 2], target_boxes[:, 3]

    # 计算交集和并集的左上角和右下角坐标
    xi1 = torch.max(pred_x1, true_x1)
    yi1 = torch.max(pred_y1, true_y1)
    xi2 = torch.min(pred_x2, true_x2)
    yi2 = torch.min(pred_y2, true_y2)

    # 计算交集和并集的面积
    inter_area = torch.clamp(xi2 - xi1, min=0) * torch.clamp(yi2 - yi1, min=0)
    pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)
    true_area = (true_x2 - true_x1) * (true_y2 - true_y1)
    union_area = pred_area + true_area - inter_area

    # 计算IoU
    iou = inter_area / union_area

    # 计算IoU损失
    iou_loss = 1.0 - iou

    return iou_loss

需要注意的是,本代码的损失是基于框的左上角和右下角两个点的坐标值的:

所以小伙伴在使用的时候也要注意,传入的数据格式

猜你喜欢

转载自blog.csdn.net/weixin_53374931/article/details/130441242