pytorch实现yolo-v3 (源码阅读和复现) -- 004算法分析

对上一步模型直接检测层的预测结果进行进一步过滤, 核心还是nms

1.核心代码

def write_results(predictions, confidence, num_class, nms=True, nms_thresh=0.4):
    # 保留预测结果中置信度大于给定阈值的部分
    # confidence: shape=(1,10647, 85)
    # mask: shape=(1,10647) => 增加一维度之后 (1, 10647, 1)
    mask = (predictions[:, :, 4] > confidence).float().unsqueeze(2)
    predictions = predictions*mask # 小于置信度的条目值全为0, 剩下部分不变

    # 如果没有检测任何有效目标,返回值为0
    ind_nz = torch.nonzero(predictions[:, :, 4].squeeze()).squeeze()
    if ind_nz.size(0) == 0:
        return 0
    # predictions = predictions[:, ind_nz, :]

    '''
    保留预测结果中置信度大于阈值的bbox
    下面开始为nms准备
    '''

    # prediction的前五个数据分别表示 (Cx, Cy, w, h, score)
    bbox = predictions.new(predictions.shape)
    bbox[:, :, 0] = (predictions[:, :, 0] - predictions[:, :, 2]/2) # x1 = Cx - w/2
    bbox[:, :, 1] = (predictions[:, :, 1] - predictions[:, :, 3]/2) # y1 = Cy - h/2
    bbox[:, :, 2] = (predictions[:, :, 0] + predictions[:, :, 2]/2) # x2 = Cx + w/2
    bbox[:, :, 3] = (predictions[:, :, 1] + predictions[:, :, 3]/2) # y2 = Cy + h/2
    predictions[:, :, :4] = bbox[:, :, :4] # 计算后的新坐标复制回去

    batch_size = predictions.size(0) # dim=0
    # output = predictions.new(1, predictions.size(2)+1) # shape=(1,85+1)

    write = False # 拼接结果到output中最后返回
    for ind in range(batch_size):
        # 选择此batch中第ind个图像的预测结果
        prediction = predictions[ind]
        # 结果过滤
        ind_nz = torch.nonzero(prediction[:, 4].squeeze()).squeeze()
        if ind_nz.size(0) == 0:
            continue
        prediction = prediction[ind_nz, :]
        # print(prediction.shape) # shape=(10647->14, 85)

        # 最大值, 最大值索引, 按照dim=1 方向计算
        max_score, max_score_ind = torch.max(prediction[:, 5:], 1) # prediction[:, 5:]表示每一分类的分数
        # 维度扩展
        # max_score: shape=(10647->14) => (10647->14,1)
        max_score = max_score.float().unsqueeze(1)
        max_score_ind = max_score_ind.float().unsqueeze(1)
        seq = (prediction[:, :5], max_score, max_score_ind) # 取前五
        prediction = torch.cat(seq, 1) # shape=(10647, 5+1+1=7)
        # print(prediction.shape)


        # 获取当前图像检测结果中出现的所有类别
        try:
            image_classes = unique(prediction[:, -1]) # tensor, shape=(n)
        except:
            continue

        # 执行classwise nms
        for cls in image_classes:
            # 分离检测结果中属于当前类的数据
            # -1: cls_index, -2: score
            class_mask = (prediction[:, -1] == cls) # shape=(n)
            class_mask_ind = torch.nonzero(class_mask).squeeze() # shape=(n,1) => (n)

            # prediction_: shape(n,7)
            prediction_class = prediction[class_mask_ind].view(-1, 7) # 从prediction中取出属于cls类别的所有结果,为下一步的nms的输入

            ''' 到此步 prediction_class 已经存在了我们需要进行非极大值抑制的数据 '''
            # 开始 nms
            # 按照score排序, 由大到小
            # 最大值最上面
            score_sort_ind = torch.sort(prediction_class[:, 4], descending=True)[1] # [0] 排序结果, [1]排序索引
            prediction_class = prediction_class[score_sort_ind]
            cnt = prediction_class.size(0) # 个数

            '''开始执行 "非极大值抑制" 操作'''
            if nms:
                for i in range(cnt):
                    # 对已经有序的结果,每次开始更新后索引加一,挨个与后面的结果比较
                    try:
                        ious = bbox_iou(prediction_class[i].unsqueeze(0), prediction_class[i+1:])
                    except ValueError:
                        break
                    except IndexError:
                        break

                    # 计算出需要移除的item
                    iou_mask = (ious < nms_thresh).float().unsqueeze(1)
                    prediction_class[i+1:] *= iou_mask # 保留i自身
                    # 开始移除
                    non_zero_ind = torch.nonzero(prediction_class[:, 4].squeeze())
                    prediction_class = prediction_class[non_zero_ind].view(-1, 7)

                    # iou_mask = (ious < nms_thresh).float() # shape=(n)
                    # non_zero_ind = torch.nonzero(iou_mask).squeeze()+1 # 会为空,导致出错
                    # prediction_class = prediction_class[non_zero_ind].view(-1, 7)

            # 当前类的nms执行完之后,保存结果
            batch_ind = prediction_class.new(prediction_class.size(0), 1).fill_(ind)
            seq = batch_ind, prediction_class

            if  not write:
                output = torch.cat(seq, 1)
                write = True
            else:
                out = torch.cat(seq, 1)
                output = torch.cat((output, out))

    return output

2. 算法分析

第一步显示对目标区域置信度低于阈值的目标(低于阈值认为是bg)剔除掉, 后面的结果在进行nms过滤

在做nms之前, 对bbox坐标进行了变换, 从( Cx, Cy, w,h)变为(x1,y1, x2,y2),这样方便计算 iou

预测结果是一个batch,包含了n>=1张图像, 开始循环{
    取出当前图像的预测结果, 
    过滤掉当前一张图像中置信度低于阈值的结果
    统计当前预测结果包含的分类(先做排序,由大到小), 循环{

        取出当前图像预测结果中属于当前类的预测结果, 
        对当前类执行nms,循环
        {
            取当前第i项和后面[i+1:]分别计算iou, 统计重叠区域大于阈值的部分,剔除掉, 
            更新预测结果,知道索引越界
        }
        此时的预测结果中保留的值已经是有效的了,放入到output返回值中之前,需要对其在扩展以为,放入信息:所属图像索引
    }
}

猜你喜欢

转载自blog.csdn.net/u010472607/article/details/81564629
今日推荐