[BEV] 学习笔记之BEVFormer(二)

1、前言

在上一篇中介绍了BEVFormer的大体流程,地址为:https://zhuanlan.zhihu.com/p/593998659,由于本项目中涉及到许多变量重复且变量名重复使用,导致在代码阅读中会有一定的难度,本文注重在给关键的变量进行注释,下文中的内容仅仅我个人的理解,如果有错误的地方,烦请各位大佬说明并进行改正。
本人也是初学者,欢迎正在学习或者想学习BEV模型的朋友加入交流群一起讨论、学习论文或者代码实现中的问题 ,可以加 v群:Rex1586662742,q群:468713665

2、forward 过程详解

本文依旧是按照forward的过程对变量进行说明。

1、tools/test.py

outputs = custom_multi_gpu_test(model, data_loader, args.tmpdir,args.gpu_collect)
# 进入到projects/mmdet3d_plugin/bevformer/apis/test.py

2、projects/mmdet3d_plugin/bevformer/apis/test.py

def custom_multi_gpu_test(...):
    ...
    for i, data in enumerate(data_loader):
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
        # 进入到 projects/mmdet3d_plugin/bevformer/detectors/bevformer.py
        ...

3、projects/mmdet3d_plugin/bevformer/detectors/bevformer.py

def forward(...):
    if return_loss:
        return self.forward_train(**kwargs)
    else:
        return self.forward_test(**kwargs)
        # 进入到 self.forward_test 中
    
def forward_test(...):
    ...
    # forward
    new_prev_bev, bbox_results = self.simple_test(...)
    ...
def simple_test(...):
    # self.extract_feat 主要包括两个步骤 img_backbone、img_neck,通过卷积提取特征
    # 网络为resnet + FPN
    # 如果是base模型,img_feats 为四个不同尺度的特征层
    # 如果是small、tiny,img_feats 为一个尺度的特征层
    img_feats = self.extract_feat(img=img, img_metas=img_metas)
    # Temproral Self-Attention + Spatial Cross-Attention
    new_prev_bev, bbox_pts = self.simple_test_pts(
            img_feats, img_metas, prev_bev, rescale=rescale)
def simple_test_pts(...):
    # 对特征层进行编解码
    outs = self.pts_bbox_head(x, img_metas, prev_bev=prev_bev)
    # 进入到 projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py

4、projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py

class BEVFormerHead(DETRHead):
    def __init__(...):
       	if not self.as_two_stage:
           # 可学习的位置编码
			self.bev_embedding = nn.Embedding(self.bev_h * self.bev_w, self.embed_dims)
            self.query_embedding = nn.Embedding(self.num_query,self.embed_dims * 2)
    def forward(...):
        '''
        mlvl_feats: (tuple[Tensor]) FPN网络输出的多尺度特征
        prev_bev: 上一时刻的 bev_features
        all_cls_scores: 所有的类别得分信息
        all_bbox_preds: 所有预测框信息
        '''
        # 特征编码 (900,512)  (900,256) concate (900 + 256)
        object_query_embeds = self.query_embedding.weight.to(dtype)
        # [2500,256] bev特征图的大小,最终bev的大小为 50*50,每个点的channel维度为256。(base模型的特征图大小为200 * 200)
        bev_queries = self.bev_embedding.weight.to(dtype)
        # [1,50,50] 每个特征点对应一个mask点
        bev_mask = torch.zeros((bs, self.bev_h, self.bev_w),         device=bev_queries.device).to(dtype)
        # [1, 256, 50, 50] 可学习的位置编码
        bev_pos = self.positional_encoding(bev_mask).to(dtype)
        if only_bev:
            ...
        else:
            # mlvl_feats ,多尺度特征
            # bev_queries ,200*200,256
            # object_query_embeds = 900 * 512 # 检测头使用的部分
            outputs = ...
        outputs = self.transformer(...)
        # 进入到 projects/mmdet3d_plugin/bevformer/modules/transformer.py
        
        for lvl in range(hs.shape[0]):
            # 类别
            outputs_class = self.cls_branches[lvl](hs[lvl])
            # 回归框信息
            tmp = self.reg_branches[lvl](hs[lvl])
            

5、projects/mmdet3d_plugin/bevformer/modules/transformer.py

class PerceptionTransformer(...):
    def __init__(...):
        ...
    def forward(...):
       # 获得bev特征 temporal_self_attention + spatial_cross_attention
        bev_embed = self.get_bev_features(...)
    def get_bev_features(...):
        # 车身底盘信号:速度、加速度等
        # 当前帧的bev特征与历史特征进行  时间、空间上的对齐
        delta_x = ...
        # BEV特征中 每一格 在真实世界中对应的长度
        grid_length_x = 0.512
        grid_length_x = 0.512
        # 上帧和当前帧的偏移量
        shift_x = ...
        shift_y = ...
        if prev_bev is not None:
            ...
            if self.rotate_prev_bev:
                # 车身旋转角度
                rotation_angle = ...
        # can信号映射到 256维度
        can_bus = self.can_bus_mlp(can_bus)[None, :, :]
        # bev特征加上can_bus特征
        bev_queries = bev_queries + can_bus * self.use_can_bus
        
        # sca 有关
        for lvl, feat in enumerate(mlvl_feats):
            # 特征编码
            if self.use_cams_embeds:
                feat = feat + self.cams_embeds[:, None, None, :].to(feat.dtype)
             feat = feat + self.level_embeds[None, None, lvl:lvl + 1, :].to(feat.dtype)
        # 每一个维度的起始点
        level_start_index = ...
        
        # 获得bev特征 block * 6
        bev_embed = self.encoder(...)
        # 进入到projects/mmdet3d_plugin/bevformer/modules/encoder.py
        ...
        # decoder
        inter_states, inter_references = self.decoder(...)
        # 进入到 projects/mmdet3d_plugin/bevformer/modules/decoder.py 中
        
        return bev_embed, inter_states, init_reference_out, inter_references_out
        # 返回到projects/mmdet3d_plugin/bevformer/dense_heads/bevformer_head.py

6、projects/mmdet3d_plugin/bevformer/modules/encoder.py

class BEVFormerEncoder(...):
    def __init__(self):
        ...
    def get_reference_points(...):
        '''
        获得参考点用于 SCA以及TSA
        H:bev_h
        W:bev_w
        Z:pillar的高度
        num_points_in_pillar:4,在每个pillar里面采样四个点
        '''
        # SCA
        if dim == '3d':
            # (4, 50, 50) 为每一个bev_query特征点在0~Z上均匀采样4个点,并归一化
            zs = ... 
            # 均匀采样的x坐标
            xs = ...
            # 均匀采样的y坐标
            ys = ...
            # (1, 4, 2500, 3)
            ref_3d = 
        # TSA
        elif dim == '2d':
            # bev特征点坐标
            ref_2d = ...
    def point_sampling(...)
        '''
        pc_range: bev特征表征的真实的物理空间大小
        img_metas: 数据集 list [(4*4)] * 6
        '''
        # 4×4 为 雷达坐标系转图像坐标系的齐次矩阵
        # 采用lidar 的坐标系
        lidar2img = ...
        # 参考坐标转化的尺度转化为真实尺度
        # [x, y, z, 1]
        reference_points = ...
        # (4,4) * [x,y,z,1] -> (zc * u , zc * v, zc, 1)  像素空间
        reference_points_cam = torch.matmul(lidar2img.to(torch.float32),  reference_points.to(torch.float32)).squeeze(-1)
        # 通过阈值判断,对bev_query的每个坐标进行 #判断,高于阈值的为True,否则为False,用于减少计算量
        # zc 大于 eps 的 为true
        bev_mask = (reference_points_cam[..., 2:3] > eps)
        # 0~1之间
        reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
        reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
        # 确保所有点在正确范围内 
        bev_mask = (bev_mask & (reference_points_cam[..., 1:2] > 0.0)
                    & (reference_points_cam[..., 1:2] < 1.0)
                    & (reference_points_cam[..., 0:1] < 1.0)
                    & (reference_points_cam[..., 0:1] > 0.0))
        ...
    # 先进入到这个forward
    @auto_fp16()
    def forward(...):
        '''
        bev_query: (2500, 1, 256)
        key: (6, 375, 1, 256) 6个相机图片的特征
        value: 与key一致
        bev_pos:(2500, 1, 256) 为每个bev特征点进行可学习的编码  
        spatial_shapes: 相机特征层的尺度,tiny模型只有一个,base模型有4个
        level_start_index: 特征尺度的索引
        prev_bev:(2500, 1, 256) 前一时刻的bev_query
        shift: 当前bev特征相对于上一时刻bev特征的偏移量
        '''
        # z轴的采样点坐标 (1, 4, 2500, 3) 
        ref_3d = self.get_reference_points(...)
        # bev_query 特征点的归一化坐标 (1, 2500, 1, 2)
        ref_2d = self.get_reference_points(...)
        # (6,1,40000,4,2) 像素坐标
        reference_points_cam, bev_mask = self.point_sampling(...)
        # 当前bev特征坐标等于上一时刻bev特征+偏移量
        # 通过偏移量,可以将当前帧的bev特征点与上一帧的bev特征点联系起来
        shift_ref_2d += shift[:, None, None, :]
        if prev_bev is not None:
            # 叠加当前时刻bev_query 和上一时刻的bev_query
            prev_bev = torch.stack([prev_bev, bev_query], 1).reshape(bs*2, len_bev, -1)
        # 6 × encoder
        for lid, layer in enumerate(self.layers):
            # 进入到下面的 BEVFormerLayer 的forward中
            output = layer(...)
            

class BEVFormerLayer(MyCustomBaseTransformerLayer)
    def __init__(...):
        '''
        attn_cfgs:来自总体网络配置文件的参数
        ffn_cfgs:单层神经网络的参数
        operation_order: 'self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm',encode中每个block中包含的步骤
        '''
        # 注意力模块的个数 2
        self.num_attn 
        # 编码维度 256
        self.embed_dims
        # ffn层
        self.ffns
        # norn层
        self.norms 
        ...
    def forward(...):
        '''
        query:当前时刻的bev_query,(1, 2500, 256)
        key: 当前时刻6个相机的特征,(6, 375, 1, 256)
        value:当前时刻6个相机的特征,(6, 375, 1, 256)
        bev_pos:每个bev_query特征点 可学习的位置编码
        ref_2d:前一时刻和当前时刻bev_query对应的参考点  (2, 2500, 1, 2)
        red_3d: 当前时刻在Z轴上的采样的参考点 (1, 4, 2500, 3) 每个特征点在z轴沙漠化采样4个点
        bev_h: 50
        bev_w: 50
        reference_points_cam: (6, 1, 2500, 4, 2) 
        spatial_shapes:FPN特征层大小 [15,25]
        level_start_index: [0] spatial_shapes对应的索引
        prev_bev: 上上个时刻以及上个时刻 bev_query(2, 2500, 256)  
        '''
        # 遍历六个 encoder的 block块
        for layer in self.operation_order:
            # 首先进入tmporal_self_attention
            if layer == 'self_attn':
                # self.attentions 为 temporal_self_attention模块
                query = self.attentions[attn_index]
                # 进入到projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py
            #  Spatial Cross-Attention
            # 然后进入  Spatial Cross-Attention
            elif layer == 'cross_attn':
                query = self.attentions[attn_index]
                # 进入到 projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py
        

7、projects/mmdet3d_plugin/bevformer/modules/temporal_self_attention.py

class TemporalSelfAttention(...):
    def __init__(...):
        '''
        embed_dims: bev特征维度 256
        num_heads: 8 头注意力
        num_levels:1 多尺度特征的层数
        num_points:4,每个特征点采样四个点进行计算
        num_bev_queue:bev特征长度,及上一时刻以及当前时刻
        '''
        self.sampling_offsets = nn.Linear(...) # 学习偏置的网络
        self.attention_weights =nn.Linear(...) # 学习注意力特征的网络
        self.value_proj  = nn.Linear(...)  # 学习vaule特征的网络
        self.output_proj = nn.Linear(...)  # 输入结果的网络
    
    
    def forward(...):
        '''
        query: (1, 2500, 256) 当前时刻的bev特征图
        key: (2, 2500, 256)  上一个时刻的以及上上时刻的bev特征
        value: (2, 2500, 256) 上一个时刻的以及上上时刻的bev特征
        query_pos: 可学习的位置编码
        reference_points:每个bev特征点对应的坐标
        '''
        # 初始帧
        if value is None:
            assert self.batch_first
            bs, len_bev, c = query.shape
            value = torch.stack([query, query], 1).reshape(bs*2, len_bev, c)
        # 位置编码
        if query_pos is not None:
            query = query + query_pos
        # 将前一时刻的bev和当前时刻的bev特征进行叠加
        query = torch.cat([value[:bs], query], -1)
        # 学习前一时刻和当前时刻的bev特征 (1, 2500, 128)
        value =  self.value_proj(value)
        # 8 个头的注意力
        value = value.reshape(bs*self.num_bev_queue,
                                  num_value, self.num_heads, -1)
        
        # (1, 2500, 128)
        # 从当前时刻的bev_query 学习到 参考点的偏置
        sampling_offsets = self.sampling_offsets(query)
        # (1, 2500, 8, 2, 1, 4, 2)  
        sampling_offsets = sampling_offsets.view(
                bs, num_query, self.num_heads,  self.num_bev_queue, self.num_levels, self.num_points, 2)
        # (1, 2500, 8, 2, 4)   用于学习每个特征点之间的权重
        attention_weights = self.attention_weights(query).view(
                bs, num_query,  self.num_heads, self.num_bev_queue, self.num_levels * self.num_points)
        # offset_normalizer = (50,50)
        if reference_points.shape[-1] == 2:
            offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
            # reference_points (2, 2500, 1, 2) pre_bev 和当前bev 每个特征点的 归一化坐标 0~1之间 
        #    sampling_locations bev上每个特征点与哪些采样点进行注意力计算
            sampling_locations = reference_points [][:, :, None, :, None, :] + sampling_offsets +  offset_normalizer[None, None, None, :, None, :]
       if ...:
            ...
        else:
            # 计算deformable attention output (2, 2500, 256)
            output = multi_scale_deformable_attn_pytorch(...)
        # (2500, 256, 1, 2) 当前时刻与上个时刻的注意力特征
        output = output.view(num_query, embed_dims, bs, self.num_bev_queue)
        # 将两个时刻的注意力特征取平均值
        output = output.mean(-1)
        # 线性层
        output = self.output_proj(output)
        # 残差链接 
        return self.dropout(output) + identity
        返回到 projects/mmdet3d_plugin/bevformer/modules/encoder.py 中
        
def multi_scale_deformable_attn_pytorch(...):
    # 映射到 -1 到 1之间
    sampling_grids = 2 * sampling_locations - 1
    for level, (H_, W_) in enumerate(value_spatial_shapes):
        # 不规则采样
        sampling_value_l_ = F.grid_sample
    
    # 相乘注意力操作
    output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) *
              attention_weights).sum(-1).view(bs, num_heads * embed_dims,
                                              num_queries)

8、projects/mmdet3d_plugin/bevformer/modules/spatial_cross_attention.py

SpatialCrossAttention(...):
    def __init__(...):
        '''
        embed_dims:编码维度
        pc_range:真实世界的尺度
        deformable_attention: 配置参数
        num_cams:相机数量
        '''
        self.output_proj = nn.Linear(...) # out网络
    
    def forward(...):
        '''
        query:tmporal_self_attention的输出加上 self.norms
        reference_points:(1, 4, 2500, 3)  由 tmporal_self_attention的输出加上 模块计算的z轴上采样点的坐标,每个bev特征的有三个坐标点(x,y,z)
        bev_mask:(6, 1, 2500, 4) 某些特征点的值为false,可以将其过滤掉,2500为bev特征点个数,1为特征尺度,4,为在每个不同尺度的特征层上采样点的个数。
        '''
        # (6, 375, 1, 256)  query 轮巡到 key 上查找特征
        
        # bev_mask.shape (6, 1, 2500, 4)  
        for i, mask_per_img in enumerate(bev_mask):
            # 从每个特征层上找到有效位置的 index
            index_query_per_img = mask_per_img[0].sum(-1).nonzero().squeeze(-1)
            indexes.append(index_query_per_img)
       # bev特征层对应每个  相机特征的 最大的特征数的长度
        max_len = max([len(each) for each in indexes])
        # 将所有 相机的特征点的个数 重建为 最大特征长度
        queries_rebatch = query.new_zeros([bs, self.num_cams, max_len, self.embed_dims])
        # 将query放到   reference_points_rebatch中
        reference_points_rebatch = ...
        for j in range(bs):
            for i, reference_points_per_img in enumerate(reference_points_cam):
                # 将query和 reference_points_cam 中有效的元素提取出来
                ...
        # deformable_attention 
        queries = self.deformable_attention(...)
    
    
# self.deformable_attention
 class MSDeformableAttention3D(BaseModule):
    def __init__(...):
        '''
        embed_dims:编码维度
        num_heads:注意力头数
        num_levels: 4 
        每个z轴上的点要到每一个相机特征图上寻找两个点,所以会有8个点
        '''
        # 学习特征点偏移的网络
        self.sampling_offsets = nn.Linear(...)
        # 提取特征网络
        self.attention_weights(...)
        # 输出特征网络
        self.value_proj = nn.Linear(...)
        
    def forward(...):
        '''
        query: (1,604,256), queries_rebatch 特征筛选过后的query
        query_pos:挑选的特征点的归一化坐标
        '''
        # mlp
        value = self.value_proj(value)
        value = value.view(bs, num_value, self.num_heads, -1)
        # 从bev_query 学习到的偏置 
        sampling_offsets = ...
        # 注意力权重
        attention_weights
        ...
        if ...:
        else:
            output = ...
            ...
        output = multi_scale_deformable_attn_pytorch(
                value, spatial_shapes, sampling_locations, attention_weights)
        ...
        返回到encoder.py中
        

9、projects/mmdet3d_plugin/bevformer/modules/decoder.py

class DetectionTransformerDecoder(...):
    def __init__(...):
        ...
        
    def forward(...):
        '''
        query: [900,1,256] bev 特征
        reference_points: [1, 900, 3] 每个query 对应的 x,y,z坐标
        '''
        # 重复6次decoder
        for lid, layer in enumerate(self.layers):
            # 取x,y
            reference_points_input = reference_points[..., :2].unsqueeze(2)
            
            output = layer(...)
            # 进入到 CustomMSDeformableAttention
            # 在获得查询到的特征后,会利用回归分支(FFN 网络)对提取的特征计算回归结果,预测 10 个输出
            # (xc,yc,w,l,zc,h,rot.sin(),rot.cos(),vx,vy);[预测框中心位置的x方向偏移,预测框中心位置的y方向偏移,预测框的宽,预测框的长,预测框中心位置的z方向偏移,预测框的高,旋转角的正弦值,旋转角的余弦值,x方向速度,y方向速度]
            # 然后根据预测的偏移量,对参考点的位置进行更新,为级联的下一个 Decoder 提高精修过的参考点位置

            new_reference_points = torch.zeros_like(reference_points)
            # 预测出来的偏移量是绝对量
            # 框中心处的 x, y 坐标
            new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points[..., :2]) 
            # 框中心处的 z 坐标
            new_reference_points[..., 2:3] = tmp[..., 4:5] + inverse_sigmoid(reference_points[..., 2:3]) 
            # 计算归一化坐标
            new_reference_points = new_reference_points.sigmoid()
            reference_points = new_reference_points.detach()
            
            if self.return_intermediate:
                intermediate.append(output)
                intermediate_reference_points.append(reference_points)
            return output, reference_points
            # 返回到 projects/mmdet3d_plugin/bevformer/modules/transformer.py
            
class CustomMSDeformableAttention(...):
    def forward(...):
        '''
        query: (900, 1, 256)
        query_pos:(900, 1, 256) 可学习的位置编码
        '''
        output = multi_scale_deformable_attn_pytorch(...)
        output = self.output_proj(output)
        return self.dropout(output) + identity

4、损失函数

损失函数的计算在https://zhuanlan.zhihu.com/p/543335939中讲的比较详细了,因此本文不再进行叙述,通过对BEVFoer论文以及代码的阅读,基本上弄清楚了工作流程,主要是弄清楚了TSA、SCA是如何实现的,这是笔者详细了解的第一个BEV模型,细节上可能还会有些问题,但BEV模型还在不断更新,不得不去卷其他模型了。

猜你喜欢

转载自blog.csdn.net/weixin_42108183/article/details/128433381