Comparing demo.py in py-faster-rcnn with the C++ version, part07: NMS and selecting the qualifying boxes

Here, "the C++ version" refers to https://github.com/galian123/cpp_faster_rcnn_detect.

"demo.py in py-faster-rcnn" refers to https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/demo.py together with
some of the code under https://github.com/rbgirshick/py-faster-rcnn/tree/master/lib.

All .py files mentioned here come from https://github.com/rbgirshick/py-faster-rcnn/.

Back in demo.py, we continue with the processing that follows im_detect().

★ Displaying the detection results

♦ Python code

demo.py

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]): # see Explanation 1
        cls_ind += 1 # because we skipped background
        # From every row, take the box of class cls: columns 4*cls_ind:4*(cls_ind + 1)
        cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)] # cls_boxes has shape (300, 4)
        # From every row, take the score in column cls_ind
        cls_scores = scores[:, cls_ind]     # cls_scores has shape (300,)
        # cls_scores[:, np.newaxis] has shape (300, 1)
        dets = np.hstack((cls_boxes,        # see Explanation 2
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)    # see Explanation 3
        dets = dets[keep, :]
        vis_detections(im, cls, dets, thresh=CONF_THRESH) # see Explanation 4
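  • Explanation 1:
    CLASSES contains 21 classes; CLASSES[1:] skips the background.
    enumerate() starts cls_ind at 0, which corresponds to 'aeroplane'. Column 0 of scores (and columns 0-3 of boxes) belong to '__background__', so cls_ind + 1 is the index that matches the position of this class in the actual data (boxes and scores).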
CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')
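To make the index shift concrete, here is a small NumPy sketch. It assumes im_detect() returns boxes of shape (300, 4 * 21) and scores of shape (300, 21), as the shape comments in the loop above suggest. The arrays are hypothetical stand-ins in which every entry holds its own column index, so the slice taken for each class is easy to check.

```python
import numpy as np

CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')

num_rois = 300                 # assumed number of proposals per image
num_classes = len(CLASSES)     # 21, background included

# Hypothetical stand-ins for the im_detect() outputs: each entry holds its own
# column index, so a slice reveals which columns it came from.
boxes = np.tile(np.arange(4 * num_classes, dtype=np.float32), (num_rois, 1))
scores = np.tile(np.arange(num_classes, dtype=np.float32), (num_rois, 1))

picked = {}
for cls_ind, cls in enumerate(CLASSES[1:]):
    cls_ind += 1                                   # skip the background column
    cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
    cls_scores = scores[:, cls_ind]
    picked[cls] = (cls_ind, cls_boxes[0].tolist(), float(cls_scores[0]))

print(picked['aeroplane'])   # class index 1: box columns 4..7, score column 1
print(picked['tvmonitor'])   # class index 20: box columns 80..83, score column 20
```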
  • Explanation 2: hstack

https://docs.scipy.org/doc/numpy/reference/generated/numpy.hstack.html#numpy.hstack
Examples:

>>> a = np.array((1,2,3))
>>> b = np.array((2,3,4))
>>> np.hstack((a,b))
array([1, 2, 3, 2, 3, 4])

>>> a = np.array([[1],[2],[3]])
>>> b = np.array([[2],[3],[4]])
>>> np.hstack((a,b))
array([[1, 2],
       [2, 3],
       [3, 4]])

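So dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) appends each box's score to the end of the corresponding row of cls_boxes.
As a result, dets has shape (300, 5), and each row is (x1, y1, x2, y2, score).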
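A short sketch of the same construction with random stand-in data (the 300-row shape is assumed from the comments above). It also shows why np.newaxis is needed: np.hstack refuses to join a 2-D array with a 1-D one, whereas np.column_stack accepts the 1-D scores directly and treats them as a column.

```python
import numpy as np

rng = np.random.default_rng(0)
num_rois = 300                                        # assumed proposal count

cls_boxes = rng.uniform(0, 500, size=(num_rois, 4))   # (300, 4): x1, y1, x2, y2
cls_scores = rng.uniform(0, 1, size=num_rois)         # (300,)

column = cls_scores[:, np.newaxis]                    # (300, 1): scores as a column
dets = np.hstack((cls_boxes, column)).astype(np.float32)
print(column.shape, dets.shape)                       # (300, 1) (300, 5)

# Without np.newaxis the inputs are 2-D and 1-D, and hstack raises ValueError.
try:
    np.hstack((cls_boxes, cls_scores))
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

# np.column_stack turns a 1-D array into a column, giving the same result.
alt = np.column_stack((cls_boxes, cls_scores)).astype(np.float32)
```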

  • Explanation 3: keep = nms(dets, NMS_THRESH)

./lib/fast_rcnn/nms_wrapper.py

from fast_rcnn.config import cfg
# nms.gpu_nms here is the compiled module lib/nms/gpu_nms.so
from nms.gpu_nms import gpu_nms
from nms.cpu_nms import cpu_nms

def nms(dets, thresh, force_cpu=False):
    """Dispatch to either CPU or GPU NMS implementations."""

    if dets.shape[0] == 0:
        return []
    if cfg.USE_GPU_NMS and not force_cpu:
        # Execution takes this branch: it calls the function in gpu_nms.so
        return gpu_nms(dets, thresh, device_id=cfg.GPU_ID)
    else:
        return cpu_nms(dets, thresh)
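The compiled gpu_nms/cpu_nms modules are not shown here. To give a rough idea of what they compute, here is a minimal greedy NMS sketch in NumPy. It is not the py-faster-rcnn implementation. The name greedy_nms is hypothetical, and the +1 in the widths and heights (inclusive pixel coordinates) is an assumption meant to match the box convention used elsewhere in py-faster-rcnn. Like nms(), it returns the indices of the kept rows, which is what dets[keep, :] in demo.py expects.

```python
import numpy as np

def greedy_nms(dets, thresh):
    """Greedy NMS sketch. dets is (N, 5) with rows (x1, y1, x2, y2, score).

    Returns the indices of the kept rows, highest score first.
    Assumes inclusive pixel coordinates (hence the +1 in widths/heights).
    """
    if dets.shape[0] == 0:
        return []
    x1, y1, x2, y2, scores = (dets[:, k] for k in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]          # indices by descending score

    keep = []
    while order.size > 0:
        i, rest = order[0], order[1:]       # the best remaining box is always kept
        keep.append(int(i))
        # Intersection of box i with every other remaining box
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
        iou = inter / (areas[i] + areas[rest] - inter)
        # Suppress boxes whose IoU with box i exceeds thresh; keep the rest
        order = rest[iou <= thresh]
    return keep

dets = np.array([
    [10, 10, 50, 50, 0.9],
    [12, 12, 52, 52, 0.8],      # overlaps row 0 heavily (IoU about 0.83)
    [100, 100, 140, 140, 0.7],  # does not overlap row 0
], dtype=np.float32)

keep = greedy_nms(dets, 0.3)    # NMS_THRESH from demo.py
print(keep)                     # [0, 2]
dets = dets[keep, :]
```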
  • Explanation 4: vis_detections(im, cls, dets, thresh=CONF_THRESH)
def vis_detections(im, class_name, dets, thresh=0.5):
    # Select the qualifying boxes: those whose score (last column) is >= thresh
    inds = np.where(dets[:, -1] >= thresh)[0]
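Here is a small sketch of this filtering step, applied to hypothetical post-NMS detections for one class and using the CONF_THRESH value from demo.py. np.where with a single condition returns a tuple that holds one index array, which is why the trailing [0] is needed.

```python
import numpy as np

CONF_THRESH = 0.8   # value used in demo.py

# Hypothetical post-NMS dets for one class; rows are (x1, y1, x2, y2, score).
dets = np.array([
    [10, 10, 50, 50, 0.95],
    [100, 100, 140, 140, 0.85],
    [200, 40, 260, 90, 0.35],
], dtype=np.float32)

# np.where(cond) returns a 1-tuple of index arrays; [0] unwraps it.
inds = np.where(dets[:, -1] >= CONF_THRESH)[0]

for i in inds:
    x1, y1, x2, y2 = dets[i, :4]
    score = dets[i, -1]
    print(f"box=({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f}) "
          f"w={x2 - x1:.0f} h={y2 - y1:.0f} score={score:.2f}")
```

Because the Python test uses >=, a score exactly equal to CONF_THRESH is kept. The C++ loop below uses > instead.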

♦ C++ code

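In Python, clip_boxes() and dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) correspond to the following code in bbox_transform_inv() in the C++ version. The max/min calls clip each box to the image, and the last assignment appends the score: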

    pred[(j * num + i) * 5 + 0] = max(min(pred_ctr_x - 0.5 * pred_w, img_width - 1), 0);
    pred[(j * num + i) * 5 + 1] = max(min(pred_ctr_y - 0.5 * pred_h, img_height - 1), 0);
    pred[(j * num + i) * 5 + 2] = max(min(pred_ctr_x + 0.5 * pred_w, img_width - 1), 0);
    pred[(j * num + i) * 5 + 3] = max(min(pred_ctr_y + 0.5 * pred_h, img_height - 1), 0);
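    // scores has size (num * CLASS_NUM); scores[i*CLASS_NUM+j] is the score of class j for box i (row i, column j)
    pred[(j * num + i) * 5 + 4] = scores[i * CLASS_NUM + j];

The buffer allocations and the per-class loop (sort, NMS, then collect the boxes whose score exceeds CONF_THRESH) are as follows:

    boxes = new float[num*4];
    pred = new float[num*5*CLASS_NUM];
    pred_per_class = new float[num*5];
    sorted_pred_cls = new float[num*5];
    keep = new int[num];

    for (int i = 1; i < CLASS_NUM; i++) {// process each class; i = 0 is the background
        for (int j = 0; j< num; j++){// num is 300
            for (int k=0; k<5; k++){// 5 values: (xmin, ymin, xmax, ymax, score)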
                pred_per_class[j*5+k] = pred[(i*num+j)*5+k];
            }
        }
        boxes_sort(num, pred_per_class, sorted_pred_cls);
        _nms(keep, &num_out, sorted_pred_cls, num, 5, NMS_THRESH, 0);
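        // Check i_ < num_out first so that keep[i_] is never read out of bounds;
        // the loop stops at the first kept box whose score is not above CONF_THRESH
        for(int i_ = 0; i_ < num_out && sorted_pred_cls[keep[i_]*5+4] > CONF_THRESH; ++i_){
            vector<float> bbox;
            bbox.push_back(sorted_pred_cls[keep[i_]*5+0]); // xmin
            bbox.push_back(sorted_pred_cls[keep[i_]*5+1]); // ymin
            bbox.push_back(sorted_pred_cls[keep[i_]*5+2] - sorted_pred_cls[keep[i_]*5+0]); // width = xmax - xmin
            bbox.push_back(sorted_pred_cls[keep[i_]*5+3] - sorted_pred_cls[keep[i_]*5+1]); // height = ymax - ymin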
            bbox.push_back(i); // class type
            bbox.push_back(sorted_pred_cls[keep[i_]*5+4]); // score
            bboxes.push_back(bbox);
        }
    }
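To connect the two versions, here is a Python sketch that mimics the C++ data flow under a few assumptions. First, pred is filled using the (j * num + i) * 5 + k layout shown in bbox_transform_inv(), where j is the class, i the proposal, and k one of (x1, y1, x2, y2, score). Second, boxes_sort() is assumed to sort the rows by descending score. Third, a simple greedy NMS stands in for _nms(), whose internals are not shown here. The helper names pack_like_cpp, detect_like_cpp and greedy_nms are hypothetical.

```python
import numpy as np

NMS_THRESH = 0.3
CONF_THRESH = 0.8

def greedy_nms(dets, thresh):
    """Stand-in for _nms(): greedy NMS on (N, 5) rows (x1, y1, x2, y2, score).
    Returns kept row indices in visiting order (highest score first)."""
    x1, y1, x2, y2, s = (dets[:, k] for k in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = s.argsort()[::-1]
    keep = []
    while order.size > 0:
        i, rest = order[0], order[1:]
        keep.append(int(i))
        w = np.maximum(0.0, np.minimum(x2[i], x2[rest]) - np.maximum(x1[i], x1[rest]) + 1)
        h = np.maximum(0.0, np.minimum(y2[i], y2[rest]) - np.maximum(y1[i], y1[rest]) + 1)
        inter = w * h
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou <= thresh]      # drop boxes overlapping box i too much
    return keep

def pack_like_cpp(boxes, scores):
    """Fill a flat buffer in the C++ `pred` layout: index (j * num + i) * 5 + k."""
    num, class_num = scores.shape
    pred = np.empty(num * 5 * class_num, dtype=np.float32)
    for j in range(class_num):           # j: class
        for i in range(num):             # i: proposal
            base = (j * num + i) * 5
            pred[base:base + 4] = boxes[i, 4 * j:4 * j + 4]
            pred[base + 4] = scores[i, j]
    return pred

def detect_like_cpp(pred, num, class_num):
    """Mirror of the per-class C++ loop; returns rows (x, y, w, h, class, score)."""
    bboxes = []
    for c in range(1, class_num):        # class 0 is the background
        # The block for class c is contiguous, like pred_per_class in the C++ code.
        per_class = pred[c * num * 5:(c + 1) * num * 5].reshape(num, 5)
        # Assumed behaviour of boxes_sort(): sort rows by score, descending
        sorted_cls = per_class[per_class[:, 4].argsort()[::-1]]
        for k in greedy_nms(sorted_cls, NMS_THRESH):
            x1, y1, x2, y2, score = sorted_cls[k]
            if not score > CONF_THRESH:  # same stopping rule as the C++ for-condition
                break
            bboxes.append([x1, y1, x2 - x1, y2 - y1, c, score])
    return bboxes

# Hypothetical input: 3 proposals, background + 2 classes (4 box columns per class).
boxes = np.array([
    [0, 0, 0, 0,  10, 10, 50, 50,      10, 10, 50, 50],
    [0, 0, 0, 0,  12, 12, 52, 52,      100, 100, 140, 140],
    [0, 0, 0, 0,  100, 100, 140, 140,  200, 200, 240, 240],
], dtype=np.float32)
scores = np.array([
    [0.05, 0.90, 0.05],
    [0.10, 0.85, 0.05],
    [0.60, 0.30, 0.10],
], dtype=np.float32)

pred = pack_like_cpp(boxes, scores)
result = detect_like_cpp(pred, num=3, class_num=3)
for row in result:
    print(row)
```

Because each class's rows are sorted before NMS and greedy NMS visits them in that order, the kept indices come out in descending-score order. That ordering is what allows the loop to stop at the first score that is not above CONF_THRESH. The output rows store width and height rather than xmax and ymax, matching bbox in the C++ code.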


Reprinted from blog.csdn.net/u013553529/article/details/79029394