The key to DETR3D: feature_sampling (3D-to-2D image feature sampling)

Pipeline: coordinate transformation -> normalized coordinates -> grid_sample() -> mask

Coordinate transformation -> normalized coordinates -> grid_sample()
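
Before walking through the full function, here is a minimal standalone sketch of what the first three steps do to a single reference point. The pc_range below is the usual nuScenes setting, while the identity lidar2img matrix and the image size are placeholders for illustration only:

import torch

pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]   # [x_min, y_min, z_min, x_max, y_max, z_max]
lidar2img = torch.eye(4)                            # placeholder 4x4 lidar-to-image matrix
img_h, img_w = 928, 1600                            # placeholder image size

ref = torch.tensor([0.6, 0.4, 0.8])                 # one reference point, normalized to [0, 1]

# 1) de-normalize back to metric lidar coordinates
low, high = torch.tensor(pc_range[:3]), torch.tensor(pc_range[3:])
xyz = ref * (high - low) + low

# 2) project with the homogeneous matrix: lidar2img @ [x, y, z, 1] = z * [u, v, 1, 1]
uvz = lidar2img @ torch.cat([xyz, torch.ones(1)])
uv = uvz[:2] / uvz[2:3].clamp(min=1e-5)             # divide by depth (guard against z <= 0)

# 3) normalize by image size to [0, 1], then rescale to [-1, 1] for F.grid_sample
uv = uv / torch.tensor([img_w, img_h])
grid_coord = (uv - 0.5) * 2
print(grid_coord)                                   # the sampling location for this point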

# Feature sampling. mlvl_feats: image features from the different FPN levels, a list of 4 tensors, each of shape [bs, num_cams, embed_dims, h, w]
def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    lidar2img = []
    # lidar2img: the 3D reference points are defined in the lidar frame, so projecting a 3D point onto an image requires the lidar-to-image transform
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    # N = num_cam = 6; reference_points: [bs, num_query, 3]
    lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # De-normalize [x, y, z] back to metric lidar coordinates using pc_range
    reference_points[..., 0:1] = reference_points[..., 0:1]*(pc_range[3] - pc_range[0]) + pc_range[0]
    reference_points[..., 1:2] = reference_points[..., 1:2]*(pc_range[4] - pc_range[1]) + pc_range[1]
    reference_points[..., 2:3] = reference_points[..., 2:3]*(pc_range[5] - pc_range[2]) + pc_range[2]
    # reference_points [bs, num_query, 3] -> homogeneous [bs, num_query, 4]
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    # num_cam = 6
    num_cam = lidar2img.size(1)
    # from [b, num_query, 4] to [b, num_cam, num_query, 4, 1]: replicate every query for each camera, since each query is a single 3D point in the lidar frame
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    # shape:[b, num_cam, num_query, 4, 4]
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)

    # project 3d -> 2d
    # shape: [b, num_cam, num_query, 4], the projection of every 3D point into every camera
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5 # small constant used as a numerical floor for the depth
    # a projection is valid only if the point lies in front of the camera (positive depth)
    mask = (reference_points_cam[..., 2:3] > eps)

    # Avoid division by zero. reference_points_cam[..., 2:3] is the depth z: lidar2img[4x4] @ [x, y, z, 1] = z * [u, v, 1, 1], so dividing by z gives the pixel coordinates [u, v, 1, 1]
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3])*eps)
    # Normalize to [0, 1]: x by image width, y by image height
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]

    # Rescale from [0, 1] to [-1, 1], the coordinate range expected by F.grid_sample
    reference_points_cam = (reference_points_cam - 0.5) * 2

    # Mask out all points outside the grid, i.e. points whose projection falls outside the given camera's image
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0) 
                 & (reference_points_cam[..., 0:1] < 1.0) 
                 & (reference_points_cam[..., 1:2] > -1.0) 
                 & (reference_points_cam[..., 1:2] < 1.0))
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5) # [bs, 1, n_q, n_c, 1, 1]
    mask = torch.nan_to_num(mask)
    sampled_feats = []
    # For each of the 4 feature levels, bilinearly interpolate a feature for every query; each level has shape [bs, num_cam, embed_dims, h, w]
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()  # N = num_cam
        # flatten the batch and camera dims so every camera image is sampled independently
        feat = feat.view(B*N, C, H, W)
        # [b, num_cam, num_query, 2] -> [b*num_cam, num_query, 1, 2]
        reference_points_cam_lvl = reference_points_cam.view(B*N, num_query, 1, 2)
        # F.grid_sample returns [b*n, c, num_query, 1]: one bilinearly interpolated feature vector per query
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        # [b, c, num_query, num_cam, 1]
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)
    # [b, c, num_query, num_cam, 1, len(mlvl_feats)]
    sampled_feats = torch.stack(sampled_feats, -1)
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam,  1, len(mlvl_feats)) 
    return reference_points_3d, sampled_feats, mask
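
A quick shape check of feature_sampling with random inputs (a hypothetical 6-camera, 4-level, 928x1600 setup in the style of nuScenes; the identity lidar2img matrices are placeholders):

import numpy as np
import torch
import torch.nn.functional as F   # feature_sampling above relies on F.grid_sample

bs, num_query, num_cams, C = 1, 900, 6, 256
mlvl_feats = [torch.rand(bs, num_cams, C, 116 >> lvl, 200 >> lvl) for lvl in range(4)]
reference_points = torch.rand(bs, num_query, 3)        # normalized to [0, 1]
pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]
img_metas = [{
    'lidar2img': [np.eye(4) for _ in range(num_cams)],  # placeholder projection matrices
    'img_shape': [(928, 1600, 3)] * num_cams,
}]

ref_3d, feats, mask = feature_sampling(mlvl_feats, reference_points, pc_range, img_metas)
print(ref_3d.shape)   # torch.Size([1, 900, 3])
print(feats.shape)    # torch.Size([1, 256, 900, 6, 1, 4])
print(mask.shape)     # torch.Size([1, 1, 900, 6, 1, 1])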

mask

attention_weights = attention_weights.sigmoid() * mask
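
The mask returned by feature_sampling has shape [bs, 1, num_query, num_cams, 1, 1], while the attention weights have shape [bs, 1, num_query, num_cams, num_points, num_levels], so the multiplication broadcasts over the last two dimensions and zeroes every weight of a camera the reference point does not project into. A small standalone sketch (shapes follow the defaults used here):

import torch

bs, num_query, num_cams, num_points, num_levels = 1, 900, 6, 5, 4
attention_weights = torch.rand(bs, 1, num_query, num_cams, num_points, num_levels)
# hypothetical validity mask: True where a query projects inside a camera's image
mask = torch.randint(0, 2, (bs, 1, num_query, num_cams, 1, 1), dtype=torch.bool)

weights = attention_weights.sigmoid() * mask     # invalid cameras get weight 0
print(weights[0, 0, 0, :, 0, 0])                 # per-camera weights of the first query
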
@ATTENTION.register_module()
class Detr3DCrossAtten(BaseModule):

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=5,
                 num_cams=6,
                 im2col_step=64,
                 pc_range=None,
                 dropout=0.1,
                 norm_cfg=None,
                 init_cfg=None,
                 batch_first=False):
        super(Detr3DCrossAtten, self).__init__(init_cfg)
        # modules and settings used by forward(); definitions follow the official DETR3D implementation
        self.dropout = nn.Dropout(dropout)
        self.pc_range = pc_range
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_levels = num_levels
        self.num_points = num_points
        self.num_cams = num_cams
        self.im2col_step = im2col_step
        self.norm_cfg = norm_cfg
        self.batch_first = batch_first
        # one weight per (camera, sampling point, feature level) for every query
        self.attention_weights = nn.Linear(embed_dims,
                                           num_cams * num_levels * num_points)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        # encodes the normalized 3D reference point (dim 3) into an embed_dims feature
        self.position_encoder = nn.Sequential(
            nn.Linear(3, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
        )

    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        # Inputs: the object queries and their 3D reference points
        if key is None:
            key = query
        if value is None:
            value = key

        if residual is None:
            inp_residual = query
        else:
            inp_residual = residual
        if query_pos is not None:
            query = query + query_pos

        # change to (bs, num_query, embed_dims)
        query = query.permute(1, 0, 2)

        bs, num_query, _ = query.size()

        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        
        # output: [bs, C, num_query, num_cams, 1, num_levels]; reference_points_3d: [bs, num_query, 3]; mask: [bs, 1, num_query, num_cams, 1, 1]
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)

        # Zero out the weights of points that project outside a camera's image
        attention_weights = attention_weights.sigmoid() * mask
        output = output * attention_weights
        output = output.sum(-1).sum(-1).sum(-1) # sum over levels, points and cameras, reducing three dims: shape [bs, c, num_query]
        output = output.permute(2, 0, 1) # [num_query, bs, c]
        
        output = self.output_proj(output) # (num_query, bs, embed_dims)
        # output is the fetched image feature; position_encoder maps the 3-dim reference point to
        # embed_dims (256), and the sum of output, the residual query and this positional feature
        # forms the refined query
        pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)

        return self.dropout(output) + inp_residual + pos_feat
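
For completeness, a hypothetical smoke test of the module with random inputs. It assumes the full DETR3D / mmdetection3d environment (the ATTENTION registry, nn, and inverse_sigmoid that the class relies on); the identity lidar2img matrices and the 6-camera, 4-level shapes are placeholders:

import numpy as np
import torch

attn = Detr3DCrossAtten(embed_dims=256, num_cams=6, num_levels=4,
                        pc_range=[-51.2, -51.2, -5.0, 51.2, 51.2, 3.0])
num_query, bs = 900, 1
query = torch.rand(num_query, bs, 256)            # (num_query, bs, embed_dims)
query_pos = torch.rand(num_query, bs, 256)
reference_points = torch.rand(bs, num_query, 3)   # normalized to [0, 1]
value = [torch.rand(bs, 6, 256, 116 >> lvl, 200 >> lvl) for lvl in range(4)]
img_metas = [{'lidar2img': [np.eye(4) for _ in range(6)],
              'img_shape': [(928, 1600, 3)] * 6}]

out = attn(query, None, value, query_pos=query_pos,
           reference_points=reference_points, img_metas=img_metas)
print(out.shape)   # torch.Size([900, 1, 256])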


Reposted from blog.csdn.net/weixin_43253464/article/details/125684501