Pipeline: project the 3D reference points into each camera view (coordinate transform) -> normalize to [-1, 1] -> sample image features with grid_sample() -> mask out invalid projections.

Coordinate transform -> normalized coordinates -> grid_sample()
import numpy as np
import torch
import torch.nn.functional as F


def feature_sampling(mlvl_feats, reference_points, pc_range, img_metas):
    # Collect the lidar-to-image projection matrix of every camera: (B, num_cam, 4, 4).
    lidar2img = []
    for img_meta in img_metas:
        lidar2img.append(img_meta['lidar2img'])
    lidar2img = np.asarray(lidar2img)
    lidar2img = reference_points.new_tensor(lidar2img)
    reference_points = reference_points.clone()
    reference_points_3d = reference_points.clone()
    # Denormalize the reference points from [0, 1] back to real-world
    # coordinates inside the point-cloud range.
    reference_points[..., 0:1] = reference_points[..., 0:1] * (pc_range[3] - pc_range[0]) + pc_range[0]
    reference_points[..., 1:2] = reference_points[..., 1:2] * (pc_range[4] - pc_range[1]) + pc_range[1]
    reference_points[..., 2:3] = reference_points[..., 2:3] * (pc_range[5] - pc_range[2]) + pc_range[2]
    # Append a homogeneous coordinate: (x, y, z) -> (x, y, z, 1).
    reference_points = torch.cat((reference_points, torch.ones_like(reference_points[..., :1])), -1)
    B, num_query = reference_points.size()[:2]
    num_cam = lidar2img.size(1)
    # Broadcast each query to every camera, then project into pixel space.
    reference_points = reference_points.view(B, 1, num_query, 4).repeat(1, num_cam, 1, 1).unsqueeze(-1)
    lidar2img = lidar2img.view(B, num_cam, 1, 4, 4).repeat(1, 1, num_query, 1, 1)
    reference_points_cam = torch.matmul(lidar2img, reference_points).squeeze(-1)
    eps = 1e-5
    # A point is valid only if it lies in front of the camera (positive depth).
    mask = (reference_points_cam[..., 2:3] > eps)
    # Perspective division, with the depth clamped away from zero.
    reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum(
        reference_points_cam[..., 2:3], torch.ones_like(reference_points_cam[..., 2:3]) * eps)
    # Normalize pixel coordinates to [0, 1] by image width/height, then to the
    # [-1, 1] range that F.grid_sample expects.
    reference_points_cam[..., 0] /= img_metas[0]['img_shape'][0][1]
    reference_points_cam[..., 1] /= img_metas[0]['img_shape'][0][0]
    reference_points_cam = (reference_points_cam - 0.5) * 2
    # Also mask out points that project outside the image plane.
    mask = (mask & (reference_points_cam[..., 0:1] > -1.0)
                 & (reference_points_cam[..., 0:1] < 1.0)
                 & (reference_points_cam[..., 1:2] > -1.0)
                 & (reference_points_cam[..., 1:2] < 1.0))
    mask = mask.view(B, num_cam, 1, num_query, 1, 1).permute(0, 2, 3, 1, 4, 5)
    mask = torch.nan_to_num(mask)
    # Bilinearly sample each feature level at the projected 2D locations.
    sampled_feats = []
    for lvl, feat in enumerate(mlvl_feats):
        B, N, C, H, W = feat.size()
        feat = feat.view(B * N, C, H, W)
        reference_points_cam_lvl = reference_points_cam.view(B * N, num_query, 1, 2)
        sampled_feat = F.grid_sample(feat, reference_points_cam_lvl)
        sampled_feat = sampled_feat.view(B, N, C, num_query, 1).permute(0, 2, 3, 1, 4)
        sampled_feats.append(sampled_feat)
    sampled_feats = torch.stack(sampled_feats, -1)
    # Final shape: (B, C, num_query, num_cam, 1, num_levels).
    sampled_feats = sampled_feats.view(B, C, num_query, num_cam, 1, len(mlvl_feats))
    return reference_points_3d, sampled_feats, mask
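
A quick way to sanity-check the shapes is a toy call with made-up sizes and an identity lidar2img; every number below is illustrative, not from a real config:

import numpy as np
import torch

B, num_query, num_cam, C, H, W = 1, 4, 1, 8, 16, 16
mlvl_feats = [torch.randn(B, num_cam, C, H, W)]        # a single feature level
reference_points = torch.rand(B, num_query, 3)         # normalized (x, y, z) in [0, 1]
pc_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0]       # nuScenes-style point-cloud range
img_metas = [{
    'lidar2img': [np.eye(4)] * num_cam,                # identity projection (toy only)
    'img_shape': [(H, W, 3)] * num_cam,
}]

ref_3d, feats, mask = feature_sampling(mlvl_feats, reference_points, pc_range, img_metas)
print(feats.shape)  # torch.Size([1, 8, 4, 1, 1, 1]) -> (B, C, num_query, num_cam, 1, num_levels)
print(mask.shape)   # torch.Size([1, 1, 4, 1, 1, 1])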
mask

The mask marks which reference points actually project inside a camera image and in front of it; in the cross-attention below it zeroes the attention weights of invalid (camera, point, level) samples:

attention_weights = attention_weights.sigmoid() * mask
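
A tiny numeric illustration of this gating, with made-up numbers; the third sample is assumed to project outside its image:

import torch

logits = torch.tensor([1.5, -0.2, 0.7])   # raw per-(cam, point, level) weights for one query
mask = torch.tensor([1., 1., 0.])         # 0 = projection fell outside the image
print(logits.sigmoid() * mask)            # tensor([0.8176, 0.4502, 0.0000])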
import torch
import torch.nn as nn
# mmcv 1.x / mmdet 2.x style imports, as used by the DETR3D repo.
from mmcv.cnn.bricks.registry import ATTENTION
from mmcv.runner.base_module import BaseModule
from mmdet.models.utils.transformer import inverse_sigmoid


@ATTENTION.register_module()
class Detr3DCrossAtten(BaseModule):

    def __init__(self,
                 embed_dims=256,
                 num_heads=8,
                 num_levels=4,
                 num_points=5,
                 num_cams=6,
                 im2col_step=64,
                 pc_range=None,
                 dropout=0.1,
                 norm_cfg=None,
                 init_cfg=None,
                 batch_first=False):
        super(Detr3DCrossAtten, self).__init__(init_cfg)
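        # The rest of __init__ is omitted in this excerpt. A plausible sketch of
        # the attributes and layers that forward() below relies on (names taken
        # from forward, shapes inferred from how attention_weights is reshaped):
        self.pc_range = pc_range
        self.num_cams = num_cams
        self.num_points = num_points
        self.num_levels = num_levels
        self.dropout = nn.Dropout(dropout)
        # One scalar weight per (camera, point, level) combination per query.
        self.attention_weights = nn.Linear(embed_dims, num_cams * num_levels * num_points)
        self.output_proj = nn.Linear(embed_dims, embed_dims)
        # Maps the 3D reference point back into the embedding space.
        self.position_encoder = nn.Sequential(
            nn.Linear(3, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dims, embed_dims),
            nn.LayerNorm(embed_dims),
            nn.ReLU(inplace=True),
        )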
    def forward(self,
                query,
                key,
                value,
                residual=None,
                query_pos=None,
                key_padding_mask=None,
                reference_points=None,
                spatial_shapes=None,
                level_start_index=None,
                **kwargs):
        if key is None:
            key = query
        if value is None:
            value = key
        # Fall back to the query itself for the residual connection.
        inp_residual = query if residual is None else residual
        if query_pos is not None:
            query = query + query_pos
        # (num_query, B, C) -> (B, num_query, C)
        query = query.permute(1, 0, 2)
        bs, num_query, _ = query.size()
        # Predict one weight per (camera, point, level) combination per query.
        attention_weights = self.attention_weights(query).view(
            bs, 1, num_query, self.num_cams, self.num_points, self.num_levels)
        # Project the reference points into the images and sample the features.
        reference_points_3d, output, mask = feature_sampling(
            value, reference_points, self.pc_range, kwargs['img_metas'])
        output = torch.nan_to_num(output)
        mask = torch.nan_to_num(mask)
        # Invalid projections get zero weight; valid ones are gated by sigmoid.
        attention_weights = attention_weights.sigmoid() * mask
        # Weighted sum over levels, points and cameras.
        output = output * attention_weights
        output = output.sum(-1).sum(-1).sum(-1)
        # (B, C, num_query) -> (num_query, B, C)
        output = output.permute(2, 0, 1)
        output = self.output_proj(output)
        # Re-embed the (inverse-sigmoid) reference points as a positional feature.
        pos_feat = self.position_encoder(inverse_sigmoid(reference_points_3d)).permute(1, 0, 2)
        return self.dropout(output) + inp_residual + pos_feat
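
To make the final fusion step concrete, here is a minimal shape walk-through with made-up sizes; the points dimension is fixed to 1, matching what feature_sampling actually returns:

import torch

B, C, num_query, num_cam, num_levels = 1, 256, 4, 6, 4
output = torch.randn(B, C, num_query, num_cam, 1, num_levels)   # sampled features
weights = torch.rand(B, 1, num_query, num_cam, 1, num_levels)   # sigmoid-gated, masked

fused = (output * weights).sum(-1).sum(-1).sum(-1)  # sum over levels, points, cams -> (B, C, num_query)
fused = fused.permute(2, 0, 1)                      # -> (num_query, B, C), DETR's sequence-first layout
print(fused.shape)                                  # torch.Size([4, 1, 256])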