# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick and Xinlei Chen
# --------------------------------------------------------
"""The data layer used during training to train a Fast R-CNN network.
RoIDataLayer implements a Caffe Python layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from model.config import cfg
from roi_data_layer.minibatch import get_minibatch
import numpy as np
import time
class RoIDataLayer(object):
"""Fast R-CNN data layer used for training."""
def __init__(self, roidb, num_classes, random=False):
"""Set the roidb to be used by this layer during training."""
self._roidb = roidb
self._num_classes = num_classes
# Also set a random flag
self._random = random
self._shuffle_roidb_inds() # 随机排列roidb
def _shuffle_roidb_inds(self):
"""Randomly permute the training roidb."""
# If the random flag is set,
# then the database is shuffled according to system time
# Useful for the validation set
if self._random: # 生成随机数种子
st0 = np.random.get_state()
millis = int(round(time.time() * 1000)) % 4294967295
np.random.seed(millis)
if cfg.TRAIN.ASPECT_GROUPING: # 使用纵横比对图像重排
# __C.TRAIN.ASPECT_GROUPING = False
widths = np.array([r['width'] for r in self._roidb])
heights = np.array([r['height'] for r in self._roidb])
horz = (widths >= heights)
vert = np.logical_not(horz)
horz_inds = np.where(horz)[0]
vert_inds = np.where(vert)[0]
inds = np.hstack((
np.random.permutation(horz_inds),
np.random.permutation(vert_inds)))
inds = np.reshape(inds, (-1, 2))
row_perm = np.random.permutation(np.arange(inds.shape[0]))
inds = np.reshape(inds[row_perm, :], (-1,))
self._perm = inds # 把roidb的索引打乱,造成的shuffle,打乱的索引存储的地方
else: # 将np.arange(len(self._roidb))打乱
self._perm = np.random.permutation(np.arange(len(self._roidb)))
# Restore the random state
if self._random:
np.random.set_state(st0)
self._cur = 0 # 相当于一个指向_perm的指针,每次取走图片后,他会跟着变化
# 依次取IMS_PER_BATCH大小个roi的index
def _get_next_minibatch_inds(self):
"""Return the roidb indices for the next minibatch."""
# IMS_PER_BATCH应该在train_rpn里面改成了1
if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb): # 越界
self._shuffle_roidb_inds()
db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH] # 下一次roidb的下标
# 相当于一个指向_perm的指针,随着取走图片的个数而变化
self._cur += cfg.TRAIN.IMS_PER_BATCH
return db_inds
def _get_next_minibatch(self):
"""Return the blobs to be used for the next minibatch.
并没有搜到 cfg.TRAIN.USE_PREFETCH
If cfg.TRAIN.USE_PREFETCH is True, then blobs will be computed in a
separate process and made available through self._blob_queue.
"""
# 获取下一批图片的索引
db_inds = self._get_next_minibatch_inds()
# 把对应引索的图像信息(dict)取出来,放入一个列表中
minibatch_db = [self._roidb[i] for i in db_inds] # 我感觉里面只有一个图像的信息'.'!
return get_minibatch(minibatch_db, self._num_classes)
def forward(self):
"""Get blobs and copy them into this layer's top blob vector."""
blobs = self._get_next_minibatch()
return blobs
Faster Rcnn 代码解读之 roi_data_layer/layer.py
猜你喜欢
转载自blog.csdn.net/qq_33193309/article/details/98877244
今日推荐
周排行