[Analysis] YOLOv3 - NMS Algorithm

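The NMS algorithm is largely the same across the YOLO series. The version described here is based on the YOLOv3 implementation. In the YOLOv3 architecture, the test pass concatenates all predicted boxes into a single output tensor, Prediction: [batch, num_anchor, 85], where the 85 values per box are [x, y, w, h, confidence, classes].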
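To make this layout concrete, here is a minimal sketch that slices a dummy prediction tensor into its parts. The tensor holds random stand-in data, the small anchor count is chosen only for readability, and the names boxes, obj_conf and class_scores are illustrative rather than taken from the original code.

```python
import torch

# Dummy stand-in for the network output: 2 images, 6 candidate boxes, 85 values each
batch, num_anchor, num_classes = 2, 6, 80
prediction = torch.rand(batch, num_anchor, 5 + num_classes)

boxes = prediction[..., :4]         # x, y, w, h
obj_conf = prediction[..., 4]       # object confidence
class_scores = prediction[..., 5:]  # one score per class

print(boxes.shape, obj_conf.shape, class_scores.shape)
```

The same index ranges (:4, 4 and 5:) are the ones the NMS code relies on.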


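for image_i, image_pred in enumerate(prediction):
    # Keep only boxes whose object confidence reaches conf_thres
    image_pred = image_pred[image_pred[:, 4] >= conf_thres]     # e.g. [10467, 85] -> [2709, 85]

    # Score = object confidence * highest class probability (2709 scores in this example)
    score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0]
    image_pred = image_pred[(-score).argsort()]                 # sort by score, descending
    class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True)   # highest class probability and its class index

    detections = [x1, y1, x2, y2, confidence, class_confs, class_preds]

    while detections is not empty:
        # Compute bbox_iou between the top-scoring box and every remaining box.
        # Boxes whose IoU exceeds nms_thres AND whose class matches the top box
        # are suppressed: they are merged into the top box and removed.
        # The top box is kept, and the loop repeats on the boxes that are left.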
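As a worked example of the filtering, scoring and sorting steps outlined above, the sketch below runs them on four hand-made boxes with only two classes. All values are hypothetical and picked so that the scores do not tie.

```python
import torch

# Toy predictions: [x1, y1, x2, y2, confidence, class0, class1] (hypothetical values)
image_pred = torch.tensor([
    [0., 0., 10., 10., 0.9, 0.2, 0.8],
    [1., 1., 11., 11., 0.3, 0.9, 0.1],
    [50., 50., 60., 60., 0.6, 0.9, 0.1],
    [0., 1., 10., 11., 0.7, 0.1, 0.9],
])
conf_thres = 0.5

# Drop the row whose confidence (0.3) is below the threshold
image_pred = image_pred[image_pred[:, 4] >= conf_thres]
# Score = confidence * best class probability: 0.72, 0.54, 0.63
score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0]
# Highest score first
image_pred = image_pred[(-score).argsort()]
class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True)
detections = torch.cat((image_pred[:, :5], class_confs, class_preds.float()), 1)
print(detections)
```

Each remaining row now has seven columns: four box coordinates, the object confidence, the best class probability and the predicted class index.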

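enumerate() is a Python built-in function that wraps an iterable (such as a list, tuple, or string) into an indexed sequence, giving access to each item together with its index. It is typically used in for loops.
It returns an enumerate object that yields (index, item) pairs.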
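A small illustration of enumerate, using a made-up list of class names:

```python
names = ["person", "bicycle", "car"]

# enumerate yields (index, item) pairs, so no manual counter is needed
for i, name in enumerate(names):
    print(i, name)

pairs = list(enumerate(names))
```

In the NMS code the index is used to write each image's result into the matching slot of the output list.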

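The COCO dataset is used, which has 80 classes. Together with the 4 box values and 1 confidence value, this gives the 85 values per prediction.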


Complete NMS code

import torch

# xywh2xyxy and bbox_iou are helper functions that this article does not show.

def non_max_suppression(prediction, conf_thres=0.5, nms_thres=0.4):
    """
    Removes detections with lower object confidence score than 'conf_thres' and performs
    Non-Maximum Suppression to further filter detections.
    Returns detections with shape:
        (x1, y1, x2, y2, object_conf, class_score, class_pred)
    """

    # From (center x, center y, width, height) to (x1, y1, x2, y2)
    prediction[..., :4] = xywh2xyxy(prediction[..., :4])
    output = [None for _ in range(len(prediction))]
    for image_i, image_pred in enumerate(prediction):
        # Filter out confidence scores below threshold
        image_pred = image_pred[image_pred[:, 4] >= conf_thres]
        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        
        score = image_pred[:, 4] * image_pred[:, 5:].max(1)[0]		# Object confidence times class confidence
        
        # Sort by it
        image_pred = image_pred[(-score).argsort()]
        class_confs, class_preds = image_pred[:, 5:].max(1, keepdim=True)
        detections = torch.cat((image_pred[:, :5], class_confs.float(), class_preds.float()), 1)
        # Perform non-maximum suppression
        keep_boxes = []
        while detections.size(0):
            large_overlap = bbox_iou(detections[0, :4].unsqueeze(0), detections[:, :4]) > nms_thres
            label_match = detections[0, -1] == detections[:, -1]
            # Indices of boxes with lower confidence scores, large IOUs and matching labels
            invalid = large_overlap & label_match
            weights = detections[invalid, 4:5]
            # Merge overlapping bboxes by order of confidence
            detections[0, :4] = (weights * detections[invalid, :4]).sum(0) / weights.sum()
            keep_boxes += [detections[0]]
            detections = detections[~invalid]
        if keep_boxes:
            output[image_i] = torch.stack(keep_boxes)

    return output
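The function above calls two helpers, xywh2xyxy and bbox_iou, whose bodies do not appear in this article. The following is a minimal sketch of what they could look like, assuming boxes are stored as rows of four values and that bbox_iou receives corner-format boxes. It is not necessarily the implementation the original code uses (for instance, some implementations add 1 to widths and heights to count pixels inclusively).

```python
import torch

def xywh2xyxy(x):
    """Convert (center x, center y, width, height) to (x1, y1, x2, y2)."""
    y = torch.empty_like(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2
    y[..., 1] = x[..., 1] - x[..., 3] / 2
    y[..., 2] = x[..., 0] + x[..., 2] / 2
    y[..., 3] = x[..., 1] + x[..., 3] / 2
    return y

def bbox_iou(box1, box2):
    """IoU of corner-format boxes; box1 is [1, 4] and broadcasts against box2 [N, 4]."""
    inter_x1 = torch.max(box1[:, 0], box2[:, 0])
    inter_y1 = torch.max(box1[:, 1], box2[:, 1])
    inter_x2 = torch.min(box1[:, 2], box2[:, 2])
    inter_y2 = torch.min(box1[:, 3], box2[:, 3])
    # Clamp at zero so that non-overlapping boxes get zero intersection
    inter = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])
    return inter / (area1 + area2 - inter + 1e-16)

corners = xywh2xyxy(torch.tensor([[10., 10., 4., 6.]]))
top = torch.tensor([[0., 0., 10., 10.]])
others = torch.tensor([[0., 0., 10., 10.], [5., 0., 15., 10.], [20., 20., 30., 30.]])
ious = bbox_iou(top, others)
print(corners, ious)
```

With nms_thres=0.4, only the first of the three boxes in others would count as a large overlap with top.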


torch.argsort returns the indices that would sort a tensor along the given dimension, in ascending order by default:

>>> a = torch.randn(4, 4)
>>> a
tensor([[ 0.0785,  1.5267, -0.8521,  0.4065],
        [ 0.1598,  0.0788, -0.0745, -1.2700],
        [ 1.2208,  1.0722, -0.7064,  1.2564],
        [ 0.0669, -0.2318, -0.8229, -0.9280]])
>>> torch.argsort(a, dim=1)
tensor([[2, 0, 3, 1],
        [3, 2, 1, 0],
        [2, 1, 0, 3],
        [3, 2, 1, 0]])
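Because argsort is ascending by default, the NMS code negates the score before sorting. A short sketch with made-up scores shows that this matches passing descending=True:

```python
import torch

score = torch.tensor([0.54, 0.72, 0.63])

ascending = score.argsort()             # indices from smallest to largest
descending = (-score).argsort()         # negate to get largest first
same = score.argsort(descending=True)   # equivalent when there are no ties

print(ascending.tolist(), descending.tolist(), same.tolist())
```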


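torch.max with a dim argument returns both the maximum values and their indices along that dimension:

>>> a = torch.randn(4, 4)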
>>> a
tensor([[-1.2360, -0.2942, -0.1222,  0.8475],
        [ 1.1949, -1.1127, -2.2379, -0.6702],
        [ 1.5717, -0.9207,  0.1297, -1.8768],
        [-0.6172,  1.0036, -0.6060, -0.2432]])
>>> torch.max(a, 1)
torch.return_types.max(values=tensor([0.8475, 1.1949, 1.5717, 1.0036]), indices=tensor([3, 0, 0, 1]))
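The NMS code calls max(1, keepdim=True) on the class columns. The sketch below, using made-up class scores, shows how keepdim changes the result shapes and why that makes the later torch.cat possible.

```python
import torch

class_scores = torch.tensor([[0.1, 0.7, 0.2],
                             [0.6, 0.3, 0.1]])

values, indices = class_scores.max(1)                    # shapes [2] and [2]
values_k, indices_k = class_scores.max(1, keepdim=True)  # shapes [2, 1] and [2, 1]

# keepdim=True keeps a column dimension, so the results can be concatenated
# next to other [N, k] columns with torch.cat
stacked = torch.cat((values_k, indices_k.float()), 1)
print(values.shape, values_k.shape, stacked)
```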


Origin blog.csdn.net/ViatorSun/article/details/129815435