bevfusion transformation 分析

bevfusion是在mmdetection3D的代码框架上的二次开发,但是基于的版本是比较早期的版本,坐标系系统可能比较混乱。如下具体分析下训练过程中的坐标变换,点云数据增强、图像3D数据增强、LSS过程中如上两个变化的使用。

1. 点云的数据增强
 

class GlobalRotScaleTrans:
    def __init__(self, resize_lim, rot_lim, trans_lim, is_train):
        self.resize_lim = resize_lim
        self.rot_lim = rot_lim
        self.trans_lim = trans_lim
        self.is_train = is_train

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        transform = np.eye(4).astype(np.float32)

        if self.is_train:
            scale = random.uniform(*self.resize_lim)
            theta = random.uniform(*self.rot_lim)
            translation = np.array([random.normal(0, self.trans_lim) for i in range(3)])
            rotation = np.eye(3)
            
    
            # 使用base_points类对应的rotate,translate,scale函数对点云进行相应的变换
            # 注意这里用的是-theta,逆时针-theta,也就是顺时针theta
            if "points" in data:
                data["points"].rotate(-theta)
                data["points"].translate(translation)
                data["points"].scale(scale)
                       
            # 使用lidar_boxes类的rotate,translate,scale对box进行相应的变换
            # 注意这里用的是theta,顺时针theta
            gt_boxes = data["gt_bboxes_3d"]
            rotation = rotation @ gt_boxes.rotate(theta).numpy()
            gt_boxes.translate(translation)
            gt_boxes.scale(scale)
            data["gt_bboxes_3d"] = gt_boxes
            
            # 保留变换矩阵
            # 注意,这里rotation加了转置,返回的矩阵是逆时针theta的矩阵
            # 这里转置变成顺时针的矩阵,和上面的变换保持一致
            transform[:3, :3] = rotation.T * scale
            transform[:3, 3] = translation * scale

        data["lidar_aug_matrix"] = transform
        return 

base_points中的rotate

逆时针旋转theta角

# 逆时针旋转theta角
elif axis == 2 or axis == -1:
    rot_mat_T = rotation.new_tensor(
    [[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]
                )
# 转置,顺时针旋转theta角
rot_mat_T = rot_mat_T.T

# 右乘,逆时针旋转theta角度
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T

lidar_boxes中的rotate

顺时针旋转theta角

# 逆时针旋转theta角
if angle.numel() == 1:
            rot_sin = torch.sin(angle)
            rot_cos = torch.cos(angle)
            rot_mat_T = self.tensor.new_tensor(
                [[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]
            )
# 右乘,顺时针旋转theta角
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T

# 顺时针旋转theta角,同lidar_box,yaw朝向的定义
self.tensor[:, 6] += angle

inno,代码的简单更改

顺时针旋转theta角,匹配到yaw的定义变为-angle

# 逆时针旋转theta角
if angle.numel() == 1:
            rot_sin = torch.sin(angle)
            rot_cos = torch.cos(angle)
            rot_mat_T = self.tensor.new_tensor(
                [[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]]
            )
# 右乘,顺时针旋转theta角
self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T

# 顺时针旋转theta角,也就是相当于逆时针旋转-theta角,符合inno_gt_box yaw朝向的定义
#self.tensor[:, 6] += angle
self.tensor[:, 6] -= angle

2. 图像的3D增强

对应的图像进行顺序变换,resize => crop => flip => rotate,用的是PIL.img的函数

对应的矩阵也顺序的进行相乘,并进行保留,用于后续的变换对齐

图像的默认的坐标,X朝右,Y朝下,右手系。

先resize

然后处理crop,也就是平移

然后处理flip,

然后处理rotate,

PIL.img中的旋转,默认是逆时针旋转theta角,但是图像坐标系中默认是顺时针旋转,因此这里旋转矩阵和default的进行了一个-theta的变换。

还有一个要处理的是,旋转的中心不是坐标原点,而是图像中心点,因此也要特殊处理:
先平移b,然后逆时针旋转theta,然后反向平移-b,右乘相对。

    def img_transform(
        self, img, rotation, translation, resize, resize_dims, crop, flip, rotate
    ):
        # adjust image
        img = img.resize(resize_dims)
        img = img.crop(crop)
        if flip:
            img = img.transpose(method=Image.FLIP_LEFT_RIGHT)
        img = img.rotate(rotate)

        # post-homography transformation
        rotation *= resize
        translation -= torch.Tensor(crop[:2])
        if flip:
            A = torch.Tensor([[-1, 0], [0, 1]])
            b = torch.Tensor([crop[2] - crop[0], 0])
            rotation = A.matmul(rotation)
            translation = A.matmul(translation) + b
        theta = rotate / 180 * np.pi
        A = torch.Tensor(
            [
                [np.cos(theta), np.sin(theta)],
                [-np.sin(theta), np.cos(theta)],
            ]
        )
        b = torch.Tensor([crop[2] - crop[0], crop[3] - crop[1]]) / 2
        b = A.matmul(-b) + b
        rotation = A.matmul(rotation)
        translation = A.matmul(translation) + b

        return img, rotation, translation

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        imgs = data["img"]
        new_imgs = []
        transforms = []
        for img in imgs:
            resize, resize_dims, crop, flip, rotate = self.sample_augmentation(data)
            post_rot = torch.eye(2)
            post_tran = torch.zeros(2)
            new_img, rotation, translation = self.img_transform(
                img,
                post_rot,
                post_tran,
                resize=resize,
                resize_dims=resize_dims,
                crop=crop,
                flip=flip,
                rotate=rotate,
            )
            transform = torch.eye(4)
            transform[:2, :2] = rotation
            transform[:2, 3] = translation
            new_imgs.append(new_img)
            transforms.append(transform.numpy())
        data["img"] = new_imgs
        # update the calibration matrices
        data["img_aug_matrix"] = transforms
        return data

3. LSS的坐标变化的使用

3.1 lidar投影到image,获取image的depth

此时的点云是经过点云的数据增强的,首先进行lidar_aug的逆变换,恢复原始点云;

然后利用lidar2camera以及intrinsic,进行lidar to image的投影;

然后按照image_aug进行变换,这样就能得到点云xyz在最终image上的xyz,也就是uvd;

cur_coords = points[b][:, :3]
cur_img_aug_matrix = img_aug_matrix[b]
cur_lidar_aug_matrix = lidar_aug_matrix[b]
cur_lidar2image = lidar2image[b]

# inverse aug
cur_coords -= cur_lidar_aug_matrix[:3, 3]
cur_coords = torch.inverse(cur_lidar_aug_matrix[:3, :3]).matmul(
    cur_coords.transpose(1, 0)
)
# lidar2image
cur_coords = cur_lidar2image[:, :3, :3].matmul(cur_coords)
cur_coords += cur_lidar2image[:, :3, 3].reshape(-1, 3, 1)
# get 2d coords
dist = cur_coords[:, 2, :]
cur_coords[:, 2, :] = torch.clamp(cur_coords[:, 2, :], 1e-5, 1e5)
cur_coords[:, :2, :] /= cur_coords[:, 2:3, :]

# imgaug
cur_coords = cur_img_aug_matrix[:, :3, :3].matmul(cur_coords)
cur_coords += cur_img_aug_matrix[:, :3, 3].reshape(-1, 3, 1)
cur_coords = cur_coords[:, :2, :].transpose(1, 2)

# normalize coords for grid sample
cur_coords = cur_coords[..., [1, 0]]

3.2 结合image的depth,以及前序的image的feature,预估出深度

通过dtransform,一系列的卷积,进行深度的特征提取,同时resize到特征图的尺度;

和前序的特征concate到一起;

通过depthnet,预估出新的特征,以及深度对应的权重;

def get_cam_feats(self, x, d):
        B, N, C, fH, fW = x.shape

        d = d.view(B * N, *d.shape[2:])
        x = x.view(B * N, C, fH, fW)

        d = self.dtransform(d)
        x = torch.cat([d, x], dim=1)
        x = self.depthnet(x)

        depth = x[:, : self.D].softmax(dim=1)
        x = depth.unsqueeze(1) * x[:, self.D : (self.D + self.C)].unsqueeze(2)

        x = x.view(B, N, self.C, self.D, fH, fW)
        x = x.permute(0, 1, 3, 4, 5, 2)
        return x

3.3 生成视锥,以及视锥点云到bev坐标的投影关系

视椎生成,image下,每个pixel对应一系列的depth,也就是相当很多的点云;

def create_frustum(self):
        iH, iW = self.image_size
        fH, fW = self.feature_size

        ds = (
            torch.arange(*self.dbound, dtype=torch.float)
            .view(-1, 1, 1)
            .expand(-1, fH, fW)
        )
        D, _, _ = ds.shape

        xs = (
            torch.linspace(0, iW - 1, fW, dtype=torch.float)
            .view(1, 1, fW)
            .expand(D, fH, fW)
        )
        ys = (
            torch.linspace(0, iH - 1, fH, dtype=torch.float)
            .view(1, fH, 1)
            .expand(D, fH, fW)
        )

        frustum = torch.stack((xs, ys, ds), -1)
        return nn.Parameter(frustum, requires_grad=False)

首先对image_aug进行逆变换;

image u,v坐标变成,xyz坐标;

利用intrinsic以及外参,变化到lidar坐标系下;

然后进行lidar的数据增强操作;

def get_geometry(
        self,
        camera2lidar_rots,
        camera2lidar_trans,
        intrins,
        post_rots,
        post_trans,
        **kwargs,
    ):
        B, N, _ = camera2lidar_trans.shape

        # undo post-transformation
        # B x N x D x H x W x 3
        points = self.frustum - post_trans.view(B, N, 1, 1, 1, 3)
        points = (
            torch.inverse(post_rots)
            .view(B, N, 1, 1, 1, 3, 3)
            .matmul(points.unsqueeze(-1))
        )
        # cam_to_lidar
        points = torch.cat(
            (
                points[:, :, :, :, :, :2] * points[:, :, :, :, :, 2:3],
                points[:, :, :, :, :, 2:3],
            ),
            5,
        )
        combine = camera2lidar_rots.float().matmul(torch.inverse(intrins))
        points = combine.view(B, N, 1, 1, 1, 3, 3).matmul(points).squeeze(-1)
        points += camera2lidar_trans.view(B, N, 1, 1, 1, 3)

        if "extra_rots" in kwargs:
            extra_rots = kwargs["extra_rots"]
            points = (
                extra_rots.view(B, 1, 1, 1, 1, 3, 3)
                .repeat(1, N, 1, 1, 1, 1, 1)
                .matmul(points.unsqueeze(-1))
                .squeeze(-1)
            )
        if "extra_trans" in kwargs:
            extra_trans = kwargs["extra_trans"]
            points += extra_trans.view(B, 1, 1, 1, 1, 3).repeat(1, N, 1, 1, 1, 1)

        return points

3.4 提取feature到bev pillar中

此处调用了封装的bev_pool函数。

def bev_pool(self, geom_feats, x):
        B, N, D, H, W, C = x.shape
        Nprime = B * N * D * H * W

        # flatten x
        x = x.reshape(Nprime, C)

        # flatten indices
        geom_feats = ((geom_feats - (self.bx - self.dx / 2.0)) / self.dx).long()
        geom_feats = geom_feats.view(Nprime, 3)
        batch_ix = torch.cat(
            [
                torch.full([Nprime // B, 1], ix, device=x.device, dtype=torch.long)
                for ix in range(B)
            ]
        )
        geom_feats = torch.cat((geom_feats, batch_ix), 1)

        # filter out points that are outside box
        kept = (
            (geom_feats[:, 0] >= 0)
            & (geom_feats[:, 0] < self.nx[0])
            & (geom_feats[:, 1] >= 0)
            & (geom_feats[:, 1] < self.nx[1])
            & (geom_feats[:, 2] >= 0)
            & (geom_feats[:, 2] < self.nx[2])
        )
        x = x[kept]
        geom_feats = geom_feats[kept]
        
        x = bev_pool(x, geom_feats, B, self.nx[2], self.nx[0], self.nx[1])

        # collapse Z
        final = torch.cat(x.unbind(dim=2), 1)

        return final

猜你喜欢

转载自blog.csdn.net/huang_victor/article/details/130558625
今日推荐