Object detection algorithm SSD: detector (PyTorch implementation)

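After a batch of images passes through the SSD network, the output consists of predicted bbox offsets (relative to the prior boxes) and a confidence score for every class. These raw outputs must be post-processed to obtain the final detections.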

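The predictions for each image are processed separately, as follows:

  • For each bbox, find the class with the highest softmax score, excluding background (the background score must be dropped first). In the code below, this maximum score is what ranks the bbox.
  • Keep the k highest-scoring bboxes and decode them.
  • Apply multi-class NMS (non-maximum suppression) to these bboxes.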
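The decode step relies on two helpers, gcxgcy_to_cxcy and cxcy_to_xy, which the code in this post calls but does not define. The following is a minimal sketch of what they might look like, assuming the common SSD encoding in which center offsets are scaled by 10 and size offsets by 5 (variances 0.1 and 0.2). The scale factors and the function bodies are assumptions, not taken from this post.

```python
import torch

def gcxgcy_to_cxcy(gcxgcy, priors_cxcy):
    """Decode predicted offsets into center-size boxes (cx, cy, w, h).

    Assumption: offsets were encoded with variances 0.1 (centers) and
    0.2 (sizes), i.e. multiplied by 10 and 5 respectively.
    """
    cxcy = gcxgcy[:, :2] * priors_cxcy[:, 2:] / 10 + priors_cxcy[:, :2]
    wh = torch.exp(gcxgcy[:, 2:] / 5) * priors_cxcy[:, 2:]
    return torch.cat([cxcy, wh], dim=1)

def cxcy_to_xy(cxcy):
    """Convert (cx, cy, w, h) boxes to corner form (x_min, y_min, x_max, y_max)."""
    return torch.cat([cxcy[:, :2] - cxcy[:, 2:] / 2,
                      cxcy[:, :2] + cxcy[:, 2:] / 2], dim=1)

# Sanity check: zero offsets should reproduce the prior itself, in corner form.
priors = torch.tensor([[0.5, 0.5, 0.2, 0.4]])
boxes = cxcy_to_xy(gcxgcy_to_cxcy(torch.zeros(1, 4), priors))
print(boxes)
```

If the network was trained with a different encoding, the two scale factors must be changed to match it; otherwise the decoded boxes will be wrong.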

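Multi-class non-maximum suppression:

  • During NMS, a bbox may be suppressed in its highest-scoring class, so the other classes have to be considered as well. To handle this, every (bbox, class) pair whose score exceeds the score threshold is passed to NMS. As a result, the class finally reported for a bbox is not necessarily its highest-scoring one, because the bbox may already have been suppressed in that class.
  • Multi-class NMS can be implemented by first translating the bboxes of each class into a region where classes cannot interfere with one another, and then running ordinary NMS once over all of them.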
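To see why the translation trick works, here is a small dependency-free illustration. The helper names iou and shift_by_class are hypothetical and exist only for this example. It shifts each box by label * (max_coordinate + 1) and compares IoU values before and after.

```python
def iou(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union

def shift_by_class(boxes, labels):
    """Translate every box by label * (max_coordinate + 1)."""
    step = max(max(b) for b in boxes) + 1
    return [tuple(c + l * step for c in b) for b, l in zip(boxes, labels)]

boxes = [(10.0, 10.0, 50.0, 50.0),   # class 0
         (12.0, 12.0, 52.0, 52.0),   # class 1, heavily overlaps the first box
         (11.0, 11.0, 51.0, 51.0)]   # class 0, heavily overlaps the first box
labels = [0, 1, 0]
shifted = shift_by_class(boxes, labels)

iou_before = iou(boxes[0], boxes[1])     # large: the boxes overlap in image space
iou_cross = iou(shifted[0], shifted[1])  # different classes: no overlap after the shift
iou_same = iou(shifted[0], shifted[2])   # same class: overlap is unchanged
print(iou_before, iou_cross, iou_same)
```

After the shift, boxes of different classes can no longer suppress each other, while boxes of the same class keep their relative positions, so a single NMS pass behaves like one NMS pass per class. Note that this argument assumes coordinates are not strongly negative; a box extending far below zero could still reach into the neighbouring class's region.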

Full code:

import torch
def detector(priors_cxcy, predicted_locs, predicted_scores, min_score, max_overlap, top_k):
    '''
    Params:
        priors_cxcy: [8732, 4]
        predicted_locs: [N, 8732, 4]
        predicted_scores: [N, 8732, num_classes]
    '''
    result_list = []
    for locs, scores in zip(predicted_locs, predicted_scores):              # process each image's predictions separately
        result = get_bboxes_single(priors_cxcy, locs, scores, min_score, max_overlap, top_k)
        result_list.append(result)

    return result_list    

def get_bboxes_single(anchors, predicted_locs, predicted_scores, min_score, max_overlap, top_k):
    '''
    Params:
        anchors: [8732, 4]
        predicted_locs: [8732, 4]
        predicted_scores: [8732, num_classes]
    '''
    assert anchors.size(0) == predicted_locs.size(0) == predicted_scores.size(0)

    scores = predicted_scores.softmax(-1)
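    max_scores, _ = scores[:, 1:].max(dim=1)            # drop the background score; index 0 is the background class
    _, topk_inds = max_scores.topk(min(top_k, max_scores.size(0)))   # guard against top_k exceeding the number of boxes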

    anchors = anchors[topk_inds, :]
    predicted_locs = predicted_locs[topk_inds, :]
    scores = scores[topk_inds, :]

    # decode offsets into corner-form boxes; gcxgcy_to_cxcy and cxcy_to_xy are not defined in this post
    bboxes = cxcy_to_xy(gcxgcy_to_cxcy(predicted_locs, anchors))

    det_bboxes, det_labels = multiclass_nms(bboxes, scores, min_score, max_overlap)
    det_labels = det_labels + 1                         # background was dropped earlier, so add 1 to recover the real class index
    return det_bboxes, det_labels


def multiclass_nms(multi_bboxes, multi_scores, score_thr, threshold, max_num=-1):
    '''
    Params:
        multi_bboxes: [n, 4]
        multi_scores: [n, num_class]
    '''
    num_classes = multi_scores.size(1) - 1
    bboxes = multi_bboxes[:, None].expand(-1, num_classes, 4)
    scores = multi_scores[:, 1:]
    
    valid_mask = scores > score_thr
    bboxes = bboxes[valid_mask]
    scores = scores[valid_mask]
    labels = valid_mask.nonzero()[:, 1]

    if bboxes.numel() == 0:
        bboxes = multi_bboxes.new_zeros((0, 5))
        labels = multi_bboxes.new_zeros((0, ), dtype=torch.long)
        return bboxes, labels
    
    dets, keep = batch_nms(bboxes, scores, labels, threshold)
    if max_num > 0:
        dets = dets[:max_num]
        keep = keep[:max_num]
    
    return dets, labels[keep]


def batch_nms(bboxes, scores, inds, threshold):
    '''
    Params:
        bboxes: [n, 4]
        scores: [n]
    '''
    # Shift the boxes of each class into its own non-overlapping region, so that NMS only suppresses boxes of the same class
    max_coordinate = bboxes.max()
    offset = inds.to(bboxes) * (max_coordinate + 1)
    bboxes_for_nms = bboxes + offset[:, None]
    dets, keep = nms(torch.cat([bboxes_for_nms, scores[:, None]], -1), threshold)
    bboxes = bboxes[keep]
    scores = dets[:, -1]
    return torch.cat([bboxes, scores[:, None]], -1), keep

def nms(dets, threshold):
    x1 = dets[:, 0]
    y1 = dets[:, 1]
    x2 = dets[:, 2]
    y2 = dets[:, 3]

    areas = (y2-y1) * (x2-x1)
    scores = dets[:, 4]
    keep = []

    _, order = scores.sort(0, descending=True)

    while order.numel() > 0:
        if order.numel() == 1:
            i = order.item()
            keep.append(i)
            break
        else:
            i = order[0].item()
            keep.append(i)
            
        xx1 = x1[order[1:]].clamp(min=x1[i].data)
        yy1 = y1[order[1:]].clamp(min=y1[i].data)
        xx2 = x2[order[1:]].clamp(max=x2[i].data)
        yy2 = y2[order[1:]].clamp(max=y2[i].data)
        inter = (xx2-xx1).clamp(min=0) * (yy2-yy1).clamp(min=0)

        iou = inter / (areas[i]+areas[order[1:]] - inter)
        idx = (iou <= threshold).nonzero().squeeze()
        if idx.numel() == 0:
            break
        order = order[idx+1]
    
    keep = torch.tensor(keep, dtype=torch.long, device=dets.device)   # build the index on the same device as dets
    return dets[keep, :], keep 

Reposted from blog.csdn.net/qq_38600065/article/details/107333024