学习Faster R-CNN代码faster_rcnn(八)

Faster R-CNN源代码中faster_rcnn文件夹中包含三个文件:faster_rcnn.py、resnet.py、vgg16.py。

1.faster_rcnn.py注释

  1 class _fasterRCNN(nn.Module):
  2     """ faster RCNN """
  3     def __init__(self, classes, class_agnostic):#class-agnostic 方式只回归2类bounding box,即前景和背景
  4         super(_fasterRCNN, self).__init__()
  5         self.classes = classes #类别
  6         self.n_classes = len(classes)#类别数
  7         self.class_agnostic = class_agnostic #前景背景类
  8         # loss 两种loss
  9         self.RCNN_loss_cls = 0
 10         self.RCNN_loss_bbox = 0
 11 
 12         # define rpn 定义RPN网络
 13         self.RCNN_rpn = _RPN(self.dout_base_model)
 14         self.RCNN_proposal_target = _ProposalTargetLayer(self.n_classes)#候选区域对应gt
 15         self.RCNN_roi_pool = _RoIPooling(cfg.POOLING_SIZE, cfg.POOLING_SIZE, 1.0/16.0)#POOLING
 16         self.RCNN_roi_align = RoIAlignAvg(cfg.POOLING_SIZE, cfg.POOLING_SIZE, 1.0/16.0)
 17 
 18         self.grid_size = cfg.POOLING_SIZE * 2 if cfg.CROP_RESIZE_WITH_MAX_POOL else cfg.POOLING_SIZE
 19         self.RCNN_roi_crop = _RoICrop()
 20 
 21     def forward(self, im_data, im_info, gt_boxes, num_boxes):#图像 图像信息 标注信息 框数目
 22         batch_size = im_data.size(0)
 23 
 24         im_info = im_info.data
 25         gt_boxes = gt_boxes.data
 26         num_boxes = num_boxes.data
 27 
 28         # feed image data to base model to obtain base feature map
 29         #将图像数据馈送到基础模型以获得基础特征图
 30         base_feat = self.RCNN_base(im_data)
 31 
 32         # feed base feature map tp RPN to obtain rois
 33         # 特征图反馈到RPN得到ROIS
 34         rois, rpn_loss_cls, rpn_loss_bbox = self.RCNN_rpn(base_feat, im_info, gt_boxes, num_boxes)
 35 
 36         # if it is training phrase, then use ground trubut bboxes for refining
 37         #如果是在训练 用ground truth回归
 38         if self.training:
 39             roi_data = self.RCNN_proposal_target(rois, gt_boxes, num_boxes)
 40             rois, rois_label, rois_target, rois_inside_ws, rois_outside_ws = roi_data
 41 
 42             rois_label = Variable(rois_label.view(-1).long())
 43             rois_target = Variable(rois_target.view(-1, rois_target.size(2)))
 44             rois_inside_ws = Variable(rois_inside_ws.view(-1, rois_inside_ws.size(2)))
 45             rois_outside_ws = Variable(rois_outside_ws.view(-1, rois_outside_ws.size(2)))
 46         else:
 47             rois_label = None
 48             rois_target = None
 49             rois_inside_ws = None
 50             rois_outside_ws = None
 51             rpn_loss_cls = 0
 52             rpn_loss_bbox = 0
 53 
 54         rois = Variable(rois)
 55         # do roi pooling based on predicted rois
 56         #进行ROI POOLING,下面pooling方式
 57 
 58         if cfg.POOLING_MODE == 'crop':
 59             # pdb.set_trace()
 60             # pooled_feat_anchor = _crop_pool_layer(base_feat, rois.view(-1, 5))
 61             grid_xy = _affine_grid_gen(rois.view(-1, 5), base_feat.size()[2:], self.grid_size)
 62             grid_yx = torch.stack([grid_xy.data[:,:,:,1], grid_xy.data[:,:,:,0]], 3).contiguous()
 63             pooled_feat = self.RCNN_roi_crop(base_feat, Variable(grid_yx).detach())
 64             if cfg.CROP_RESIZE_WITH_MAX_POOL:
 65                 pooled_feat = F.max_pool2d(pooled_feat, 2, 2)
 66         elif cfg.POOLING_MODE == 'align':
 67             pooled_feat = self.RCNN_roi_align(base_feat, rois.view(-1, 5))
 68         elif cfg.POOLING_MODE == 'pool':
 69             pooled_feat = self.RCNN_roi_pool(base_feat, rois.view(-1,5))
 70 
 71         # feed pooled features to top model
 72         #pooling后的特征反馈到上次模型
 73         pooled_feat = self._head_to_tail(pooled_feat)
 74 
 75         # compute bbox offset
 76         #计算bounding box的偏移
 77         bbox_pred = self.RCNN_bbox_pred(pooled_feat)
 78         if self.training and not self.class_agnostic:
 79             # select the corresponding columns according to roi labels
 80             # 根据roi标签选择相应的列
 81             bbox_pred_view = bbox_pred.view(bbox_pred.size(0), int(bbox_pred.size(1) / 4), 4)
 82             bbox_pred_select = torch.gather(bbox_pred_view, 1, rois_label.view(rois_label.size(0), 1, 1).expand(rois_label.size(0), 1, 4))
 83             bbox_pred = bbox_pred_select.squeeze(1)
 84 
 85         # compute object classification probability
 86         # 计算对象分类概率
 87         cls_score = self.RCNN_cls_score(pooled_feat)
 88         cls_prob = F.softmax(cls_score, 1)
 89 
 90         RCNN_loss_cls = 0
 91         RCNN_loss_bbox = 0
 92 
 93         if self.training:
 94             # classification loss
 95             #分类损失
 96             RCNN_loss_cls = F.cross_entropy(cls_score, rois_label)
 97 
 98             # bounding box regression L1 loss
 99             #回归损失
100             RCNN_loss_bbox = _smooth_l1_loss(bbox_pred, rois_target, rois_inside_ws, rois_outside_ws)
101 
102 
103         cls_prob = cls_prob.view(batch_size, rois.size(1), -1)
104         bbox_pred = bbox_pred.view(batch_size, rois.size(1), -1)
105 
106         return rois, cls_prob, bbox_pred, rpn_loss_cls, rpn_loss_bbox, RCNN_loss_cls, RCNN_loss_bbox, rois_label
107     
108     #初始化权重
109     def _init_weights(self):
110         def normal_init(m, mean, stddev, truncated=False):#均值 标准差 
111             #截断正态 随机正态
112             """
113             weight initalizer: truncated normal and random normal.
114             """
115             # x is a parameter
116             if truncated:
117                 m.weight.data.normal_().fmod_(2).mul_(stddev).add_(mean) # not a perfect approximation
118             else:
119                 m.weight.data.normal_(mean, stddev)
120                 m.bias.data.zero_()
121 
122         normal_init(self.RCNN_rpn.RPN_Conv, 0, 0.01, cfg.TRAIN.TRUNCATED)
123         normal_init(self.RCNN_rpn.RPN_cls_score, 0, 0.01, cfg.TRAIN.TRUNCATED)
124         normal_init(self.RCNN_rpn.RPN_bbox_pred, 0, 0.01, cfg.TRAIN.TRUNCATED)
125         normal_init(self.RCNN_cls_score, 0, 0.01, cfg.TRAIN.TRUNCATED)
126         normal_init(self.RCNN_bbox_pred, 0, 0.001, cfg.TRAIN.TRUNCATED)
127 
128     def create_architecture(self):
129         self._init_modules()
130         self._init_weights()

ref:https://blog.csdn.net/weixin_43872578/article/details/87930953

猜你喜欢

转载自www.cnblogs.com/wind-chaser/p/11360073.html