Faster R-CNN源码阅读之十一:Faster R-CNN预测demo代码补完

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/DaVinciL/article/details/81983670
  1. Faster R-CNN源码阅读之零:写在前面
  2. Faster R-CNN源码阅读之一:Faster R-CNN/lib/networks/network.py
  3. Faster R-CNN源码阅读之二:Faster R-CNN/lib/networks/factory.py
  4. Faster R-CNN源码阅读之三:Faster R-CNN/lib/networks/VGGnet_test.py
  5. Faster R-CNN源码阅读之四:Faster R-CNN/lib/rpn_msr/generate_anchors.py
  6. Faster R-CNN源码阅读之五:Faster R-CNN/lib/rpn_msr/proposal_layer_tf.py
  7. Faster R-CNN源码阅读之六:Faster R-CNN/lib/fast_rcnn/bbox_transform.py
  8. Faster R-CNN源码阅读之七:Faster R-CNN/lib/rpn_msr/anchor_target_layer_tf.py
  9. Faster R-CNN源码阅读之八:Faster R-CNN/lib/rpn_msr/proposal_target_layer_tf.py
  10. Faster R-CNN源码阅读之九:Faster R-CNN/tools/train_net.py
  11. Faster R-CNN源码阅读之十:Faster R-CNN/lib/fast_rcnn/train.py
  12. Faster R-CNN源码阅读之十一:Faster R-CNN预测demo代码补完
  13. Faster R-CNN源码阅读之十二:写在最后

一、介绍
   本demo由Faster R-CNN官方提供,我只是在官方的代码上增加了注释,一方面方便我自己学习,另一方面贴出来和大家一起交流。
   这里对之前使用Faster R-CNN的demo进行预测时候的代码进行补完。

二、代码和注释
文件目录:Faster-RCNN/lib/fast_rcnn/test.py

def im_detect(sess, net, im, boxes=None):
    """Detect object classes in an image given object proposals.
    Arguments:
        net (caffe.Net): Fast R-CNN network to use
        im (ndarray): color image to test (in BGR order)
        boxes (ndarray): R x 4 array of object proposals
    Returns:
        scores (ndarray): R x K array of object class scores (K includes
            background as object category 0)
        boxes (ndarray): R x (4*K) array of predicted bounding boxes

    对输入的图片(像素信息)进行目标检测。
    :param sess: tensorflow会话
    :param net: Faster RCNN网络
    :param im: 图片像素信息
    :param boxes: R × 4的数组,表示物体的proposals
    :returns: scores(ndarray),shape为R × K的二维数组,目标类别的检测分数,K中包括背景,
              boxes(ndarray),shape为R × 4K的二维数组,目标检测的预测bbox。
    """
    #
    blobs, im_scales = _get_blobs(im, boxes)

    # When mapping from image ROIs to feature map ROIs, there's some aliasing
    # (some distinct image ROIs get mapped to the same feature ROI).
    # Here, we identify duplicate feature ROIs, so we only compute features
    # on the unique subset.
    # cfg.TEST.HAS_RPN默认为True,表示使用RPN网络,因此下方的if代码块不会被执行,这里忽略注释。下同。
    if cfg.DEDUP_BOXES > 0 and not cfg.TEST.HAS_RPN:
        v = np.array([1, 1e3, 1e6, 1e9, 1e12])
        hashes = np.round(blobs['rois'] * cfg.DEDUP_BOXES).dot(v)
        _, index, inv_index = np.unique(hashes, return_index=True,
                                        return_inverse=True)
        blobs['rois'] = blobs['rois'][index, :]
        boxes = boxes[index, :]

    # 使用RPN时
    if cfg.TEST.HAS_RPN:
        # 取出图片的blob数据
        im_blob = blobs['data']
        # 获取图片的尺寸信息,这里可以看出,获取的是经过缩放之后的图片尺寸信息。
        blobs['im_info'] = np.array([[im_blob.shape[1], im_blob.shape[2], im_scales[0]]], dtype=np.float32)
    # forward pass
    # 准备feed进网络的数据
    if cfg.TEST.HAS_RPN:
        feed_dict = {net.data: blobs['data'], net.im_info: blobs['im_info'], net.keep_prob: 1.0}
    else:
        feed_dict = {net.data: blobs['data'], net.rois: blobs['rois'], net.keep_prob: 1.0}

    run_options = None
    run_metadata = None
    # cfg.TEST.DEBUG_TIMELINE(以及cfg.TRAIN.DEBUG_TIMELINE)默认为False。修改为True后可能会产生调用库的问题。下同。
    # Couldn't open CUDA library libcupti.so.9.1. LD_LIBRARY_PATH: :/usr/local/cuda/lib64
    # 故最好不要修改此参数!!
    if cfg.TEST.DEBUG_TIMELINE:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()

    # 运行网络,获得所需要的结果,cls_score,cls_prob,bbox_pred,rois
    cls_score, cls_prob, bbox_pred, rois = sess.run(
        [net.get_output('cls_score'), net.get_output('cls_prob'), net.get_output('bbox_pred'), net.get_output('rois')],
        feed_dict=feed_dict,
        options=run_options,
        run_metadata=run_metadata)

    if cfg.TEST.HAS_RPN:
        # 确保每次只运行一张图片
        assert len(im_scales) == 1, "Only single-image batch implemented"
        # 对产生的rois进行缩放,缩放过后的rois正好可以对应于原始图像上的rois。
        boxes = rois[:, 1:5] / im_scales[0]

    # 是否使用SVM进行预测,默认为不使用即False。
    if cfg.TEST.SVM:
        # use the raw scores before softmax under the assumption they
        # were trained as linear SVMs
        scores = cls_score
    else:
        # use softmax estimated probabilities
        scores = cls_prob

    # 默认cfg.TEST.BBOX_REG为True
    if cfg.TEST.BBOX_REG:
        # Apply bounding-box regression deltas
        # 应用bbox回归的deltas
        box_deltas = bbox_pred
        # 计算pred boxes,具体计算过程请参考bbox_transform_inv函数。
        pred_boxes = bbox_transform_inv(boxes, box_deltas)
        # 对边界框进行裁剪,保证边界都在图片的范围内部。
        pred_boxes = _clip_boxes(pred_boxes, im.shape)
    else:
        # Simply repeat the boxes, once for each class
        # 仅仅简单地重复boxes,每个类别一次。
        pred_boxes = np.tile(boxes, (1, scores.shape[1]))

    #
    if cfg.DEDUP_BOXES > 0 and not cfg.TEST.HAS_RPN:
        # Map scores and predictions back to the original set of boxes
        scores = scores[inv_index, :]
        pred_boxes = pred_boxes[inv_index, :]

    if cfg.TEST.DEBUG_TIMELINE:
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        trace_file = open(str(long(time.time() * 1000)) + '-test-timeline.ctf.json', 'w')
        trace_file.write(trace.generate_chrome_trace_format(show_memory=False))
        trace_file.close()

    # 返回计算的scores和预测框的位置
    return scores, pred_boxes

文件目录:Faster-RCNN/lib/fast_rcnn/test.py

def _get_blobs(im, rois):
    """
    Convert an image and RoIs within that image into network inputs.
    将im和图片内部的rois转换成网络的输入。
    :param im: 图片的像素矩阵
    :param rois: rois
    :returns:
    """
    # 由于这里是Faster RCNN,cfg.TEST.HAS_RPN默认为True,不使用RPN的代码此处忽略注释。
    if cfg.TEST.HAS_RPN:
        # 定义一个字典
        blobs = {'data': None, 'rois': None}
        # 存储blob数据块和缩放系数。
        blobs['data'], im_scale_factors = _get_image_blob(im)
    else:
        blobs = {'data': None, 'rois': None}
        blobs['data'], im_scale_factors = _get_image_blob(im)
        # 多尺度图像(图像金字塔)
        if cfg.IS_MULTISCALE:
            if cfg.IS_EXTRAPOLATING:
                blobs['rois'] = _get_rois_blob(rois, cfg.TEST.SCALES)
            else:
                blobs['rois'] = _get_rois_blob(rois, cfg.TEST.SCALES_BASE)
        else:
            blobs['rois'] = _get_rois_blob(rois, cfg.TEST.SCALES_BASE)

    # 返回blob数据和缩放系数
    return blobs, im_scale_factors

文件目录:Faster-RCNN/lib/fast_rcnn/test.py

def _get_image_blob(im):
    """Converts an image into a network input.
    Arguments:
        im (ndarray): a color image in BGR order
    Returns:
        blob (ndarray): a data blob holding an image pyramid
        im_scale_factors (list): list of image scales (relative to im) used
            in the image pyramid

    将图片的像素矩阵转换成网络输入。
    :param im: 图片的像素矩阵
    :returns:
    """
    # 将原始的像素矩阵复制一份
    im_orig = im.astype(np.float32, copy=True)
    # 图片像素的归一化,这里采用了减去各个通道设定的像素均值的方法。
    im_orig -= cfg.PIXEL_MEANS

    # 图片的shape
    im_shape = im_orig.shape
    # 获取图片尺寸(高度,宽度)的短边和长边
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    # 保存所放过的图片的list
    processed_ims = []
    # 保存缩放的系数
    im_scale_factors = []

    # 对每个缩放的尺寸, cfg.TEST.SCALES一般取值[600]。
    for target_size in cfg.TEST.SCALES:
        # 计算对短边的缩放系数
        im_scale = float(target_size) / float(im_size_min)
        # Prevent the biggest axis from being more than MAX_SIZE
        # 为了防止缩放之后长边过长(超过设定的cfg.TEST.MAX_SIZE, 该值一般取值1000),
        # 如果过长,则im scale主要表示针对长边的缩放,将长边缩放到可接受的最大长度。
        if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
            im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
        # 对图片进行缩放
        im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
                        interpolation=cv2.INTER_LINEAR)
        # 将图片的缩放系数和缩放之后的图片分别保存到list中。
        im_scale_factors.append(im_scale)
        processed_ims.append(im)

    # Create a blob to hold the input images
    # 对处理之后的图片进行处理,产生blob
    blob = im_list_to_blob(processed_ims)

    # 返回产生的blob和缩放系数的list
    return blob, np.array(im_scale_factors)

文件目录:Faster-RCNN/lib/utils/blob.py

def im_list_to_blob(ims):
    """Convert a list of images into a network input.

    Assumes images are already prepared (means subtracted, BGR order, ...).
    将包含若干图片像素信息的list转换成blob数据块。这里的处理仅仅只是将所有的图片进行左上角的对齐。
    :param ims: 一个list,里面包含若干个图片的像素信息。
    :return: 处理之后的blob数据块。
    """
    # 返回各个维度的最大长度,这里真真有用的是最大的高度和宽度。
    max_shape = np.array([im.shape for im in ims]).max(axis=0)
    # 获取图片的总数目
    num_images = len(ims)
    # 根据图片总数目,最大高度宽度等信息,生成一个全0numpy数组,用以将图片的左上角对齐。
    blob = np.zeros((num_images, max_shape[0], max_shape[1], 3), dtype=np.float32)
    # 对每个图片
    for i in xrange(num_images):
        im = ims[i]
        # 进行赋值操作,这样的复制过程正好从blob数组的左上角开始。
        blob[i, 0:im.shape[0], 0:im.shape[1], :] = im

    # 返回
    return blob

文件目录:Faster-RCNN/lib/utils/timer.py

# coding=utf-8
# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

import time
# 一个简单的计时器
class Timer(object):
    """A simple timer."""
    def __init__(self):
        # 总时间
        self.total_time = 0.
        # 被调用的次数
        self.calls = 0
        # 开始时间
        self.start_time = 0.
        # 结束时间和开始时间之间的时间差
        self.diff = 0.
        # 平均时间
        self.average_time = 0.

    def tic(self):
        '''
        记录开始时间
        :return: None
        '''
        # using time.time instead of time.clock because time time.clock
        # does not normalize for multithreading
        self.start_time = time.time()

    def toc(self, average=True):
        '''
        结束计时
        :param average: True或者False,为True时返回平均时间,否则返回时间差
        :return: 平均时间或者时间差
        '''
        # 记录结束时距离开始的时间差值
        self.diff = time.time() - self.start_time
        # 将时间差值加到总时间上,并把调用次数加1
        self.total_time += self.diff
        self.calls += 1
        # 重新计算平均时间
        self.average_time = self.total_time / self.calls
        # 根据average返回
        if average:
            return self.average_time
        else:
            return self.diff

猜你喜欢

转载自blog.csdn.net/DaVinciL/article/details/81983670