Learning the Faster R-CNN Code: roi_pooling (II)

roi_pooling is relatively simple to understand, so I started by reading this part of the code.

The roi_pooling directory

The src folder contains the C and CUDA versions of the source.

roi_pool.py in the functions folder defines a class that inherits from torch.autograd.Function and implements the forward and backward passes of the RoI layer: class RoIPoolFunction(Function).

roi_pool.py in the modules folder defines a class that inherits from torch.nn.Module and wraps that function up as a layer, so the RoI layer can then be used just like any other layer (e.g. ReLU): class _RoIPooling(Module).

The _ext folder contains a roi_pooling subfolder that holds the C/CUDA files from src after compilation; once compiled, they can be called from roi_pool.py in functions. The overall layout is sketched below.
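
Roughly, the layout looks like this (reconstructed from the description above; exact file names may differ in the actual repo):

roi_pooling/
├── _ext/
│   └── roi_pooling/      # compiled C/CUDA extension, built from src
├── functions/
│   └── roi_pool.py       # RoIPoolFunction(Function): forward/backward
├── modules/
│   └── roi_pool.py       # _RoIPooling(Module): layer wrapper
└── src/                  # C and CUDA sources (roi_pooling.c, CUDA kernels)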

Specific code:

Below are the key variables from roi_pooling/src/roi_pooling.c; knowing what they stand for is the basis for understanding the code. Reading only the Python code, in some places you really cannot tell what information is involved, which makes it seem hard; reading the C code carefully clears those doubts up.

rois[0]: num_rois, the number of RoIs;
rois[1]: size_rois, the size of each RoI entry;
features[0]: batch_size, the batch size;
features[1]: data_height, the height;
features[2]: data_width, the width;
features[3]: num_channels, the number of channels.
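
Restated as a small Python sketch (my own illustration; the five-value RoI layout (batch_index, x1, y1, x2, y2) is the convention this repo family uses and is assumed here). Note that the height/width/channel indices describe features after the Python forward pass has permuted it from NCHW to NHWC for the CPU path:

import torch

# Shapes as the C code sees them. Assumption: each RoI row is
# (batch_index, x1, y1, x2, y2), so size_rois == 5.
features = torch.randn(2, 600, 800, 512)  # NHWC, i.e. already permuted
rois = torch.zeros(128, 5)

num_rois  = rois.size(0)         # rois[0]: number of RoIs
size_rois = rois.size(1)         # rois[1]: values per RoI
batch_size   = features.size(0)  # features[0]: batch size
data_height  = features.size(1)  # features[1]: height
data_width   = features.size(2)  # features[2]: width
num_channels = features.size(3)  # features[3]: channels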

For a more careful understanding, I also found another blog's walkthrough helpful: https://blog.csdn.net/weixin_43872578/article/details/86628515


functions/roi_pool.py

import torch
from torch.autograd import Function
from .._ext import roi_pooling
import pdb

# Re-implements the forward and backward passes of the RoI layer;
# roi_pool in modules wraps this function up as a layer.

class RoIPoolFunction(Function):
    def __init__(ctx, pooled_height, pooled_width, spatial_scale):
        # ctx is a context object that can be used to stash information
        # for the backward computation
        ctx.pooled_width = pooled_width
        ctx.pooled_height = pooled_height
        ctx.spatial_scale = spatial_scale
        ctx.feature_size = None

    def forward(ctx, features, rois):
        ctx.feature_size = features.size()
        batch_size, num_channels, data_height, data_width = ctx.feature_size
        num_rois = rois.size(0)
        output = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_()
        ctx.argmax = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().int()
        ctx.rois = rois
        if not features.is_cuda:
            # The CPU C kernel expects NHWC, so permute from NCHW first
            _features = features.permute(0, 2, 3, 1)
            roi_pooling.roi_pooling_forward(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
                                            _features, rois, output)
        else:
            # The CUDA kernel also records the argmax locations for backward
            roi_pooling.roi_pooling_forward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
                                                 features, rois, output, ctx.argmax)

        return output

    def backward(ctx, grad_output):
        assert(ctx.feature_size is not None and grad_output.is_cuda)
        batch_size, num_channels, data_height, data_width = ctx.feature_size
        grad_input = grad_output.new(batch_size, num_channels, data_height, data_width).zero_()

        # Each output gradient flows back to the input location that won the max
        roi_pooling.roi_pooling_backward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
                                              grad_output, ctx.rois, grad_input, ctx.argmax)

        return grad_input, None
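
To make concrete what the C/CUDA kernels actually compute, here is a minimal pure-PyTorch sketch of RoI max pooling (my own illustration, not code from this repo; the exact rounding and clipping in the real kernels may differ slightly):

import torch

def roi_pool_reference(features, rois, pooled_h, pooled_w, spatial_scale):
    """Naive RoI max pooling. features: (N, C, H, W); rois: (R, 5) rows of
    (batch_index, x1, y1, x2, y2) in image coordinates (assumed layout)."""
    channels, height, width = features.size(1), features.size(2), features.size(3)
    output = features.new_zeros(rois.size(0), channels, pooled_h, pooled_w)
    for n in range(rois.size(0)):
        batch_idx = int(rois[n, 0])
        # Scale RoI corners from image space down to feature-map space.
        x1, y1, x2, y2 = [int(round(float(v) * spatial_scale)) for v in rois[n, 1:]]
        bin_h = max(y2 - y1 + 1, 1) / pooled_h
        bin_w = max(x2 - x1 + 1, 1) / pooled_w
        for ph in range(pooled_h):
            for pw in range(pooled_w):
                # Boundaries of this bin on the feature map, clipped to be valid.
                hs = min(max(y1 + int(ph * bin_h), 0), height)
                he = min(max(y1 + int((ph + 1) * bin_h + 1), 0), height)
                ws = min(max(x1 + int(pw * bin_w), 0), width)
                we = min(max(x1 + int((pw + 1) * bin_w + 1), 0), width)
                if he > hs and we > ws:
                    patch = features[batch_idx, :, hs:he, ws:we]
                    # Per-channel max over the bin; the CUDA kernel records the
                    # winning locations (argmax) to route gradients in backward.
                    output[n, :, ph, pw] = patch.reshape(channels, -1).max(dim=1)[0]
    return output

In words: each RoI is mapped from image space onto the feature map via spatial_scale, divided into a pooled_h × pooled_w grid, and max-pooled per bin, giving every RoI a fixed-size output regardless of its original size.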

modules/roi_pool.py

from torch.nn.modules.module import Module
from ..functions.roi_pool import RoIPoolFunction

# Wraps the roi_pooling function: this is the RoI Pooling layer.

class _RoIPooling(Module):
    def __init__(self, pooled_height, pooled_width, spatial_scale):
        super(_RoIPooling, self).__init__()

        self.pooled_width = int(pooled_width)
        self.pooled_height = int(pooled_height)
        self.spatial_scale = float(spatial_scale)

    def forward(self, features, rois):
        return RoIPoolFunction(self.pooled_height, self.pooled_width, self.spatial_scale)(features, rois)
        # Directly calls the function from functions, which already
        # implements the forward and backward passes.
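
As a quick illustration of how the wrapper is used (my own example, not from the repo; spatial_scale = 1/16 matches a typical VGG16 conv5 feature map, and the (batch_index, x1, y1, x2, y2) RoI layout is assumed):

import torch

# Hypothetical usage; the legacy Function style above only runs on the
# old PyTorch versions this repo targets.
roi_pool = _RoIPooling(pooled_height=7, pooled_width=7, spatial_scale=1.0 / 16)

features = torch.randn(1, 512, 38, 50)             # backbone feature map (NCHW)
rois = torch.tensor([[0., 64., 64., 320., 320.]])  # (batch_index, x1, y1, x2, y2)

pooled = roi_pool(features, rois)  # -> shape (num_rois, 512, 7, 7)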

You can read the remaining code in src and _ext on your own; it implements the forward and backward passes of roi_pooling in C and CUDA, so that they can be called from Python.
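
For context, extensions like this were typically compiled with the old torch.utils.ffi interface, which has been removed from modern PyTorch. The sketch below is reconstructed from memory of this repo family's build script; the file names and options are assumptions, so check the repo's own build.py:

import torch
from torch.utils.ffi import create_extension  # old PyTorch (<= 0.4) only

# Assumed source/header names; see the repo's src folder for the real ones.
sources = ['src/roi_pooling.c']
headers = ['src/roi_pooling.h']
defines = []
with_cuda = False

if torch.cuda.is_available():
    sources += ['src/roi_pooling_cuda.c']
    headers += ['src/roi_pooling_cuda.h']
    defines += [('WITH_CUDA', None)]
    with_cuda = True

ffi = create_extension(
    '_ext.roi_pooling',  # produces the _ext/roi_pooling package imported above
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
)

if __name__ == '__main__':
    ffi.build()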

[Unfinished, to be updated...]

References:

https://www.jianshu.com/p/d674e16ce896

https://blog.csdn.net/weixin_43872578/article/details/86616801
