[NMS] supresión no máxima supresión no máxima y su realización

Los negros son las personas detectadas, los azules son los cuadros de detección y los números azules son los puntajes de probabilidad para cada cuadro.
cuando detectamos un objeto.
Habrá varios marcos de detección. En este momento, necesitamos usar NMS para eliminar marcos redundantes.
Primero elegimos la casilla con la puntuación más alta, en este caso la casilla con una puntuación de 0,8.
En este momento, necesitamos juzgar el valor del iou de la casilla con el puntaje más alto y otra casilla. Supongamos que establecemos el umbral estándar de iou en 0.5
. Si es menor que, no lo elimine.

inserte la descripción de la imagen aquí

Tenga en cuenta que esta operación debe ser para el mismo objeto. Si este es el caso:
inserte la descripción de la imagen aquí
implementemos NMS a continuación:

import torch
from iou import intersection_over_union
def non_max_suppression(
    bboxes,
    iou_threshold,
    prob_threshold,
    box_format="corners"
):
    # bboxes的结构假设是这样的【类别,框的概率,x1,y1,x2,y2】
    assert type(bboxes) == list
    bboxes = [box for box in bboxes if box[1] > prob_threshold]
    bboxes = sorted(bboxes,key= lambda x:x[1],reverse=True)
    bboxes_after_nms = []

    while bboxes:
        chosen_box = bboxes.pop(0)

        bboxes = [
            box for box in bboxes
            if box[0] != chosen_box[0]
            or intersection_over_union(
                torch.tensor(chosen_box[2:]),
                torch.tensor(box[2:])
            )
            < iou_threshold
        ]
        bboxes_after_nms.append(chosen_box)
    return bboxes_after_nms

iou es el iou de mi columna de detección de objetivos

Supongo que te gusta

Origin blog.csdn.net/weixin_54130714/article/details/123612281
Recomendado
Clasificación