Detailed Explanation of PointNet++ Classification and Segmentation

Foreword

PointNet++ is a deep neural network for classification and segmentation of irregularly shaped point cloud data. Compared with traditional grid-based 3D representations, point cloud data is easier to acquire and process. Another advantage of PointNet++ is its multi-scale hierarchy, which lets it handle more complex point cloud data. Compared with the first version, PointNet, the authors introduced many new ideas and achieved very good results.

Problems with the PointNet algorithm

(1) A point cloud contains a very large number of points, which makes computation expensive and slows the algorithm down. How can this be reduced?

(2) How can the point cloud be divided into regions so that local features can be extracted from each region?

(3) How should non-uniform point density be handled?

With these questions in mind, let's work through the paper and the source code.

Classification task

[Figure: PointNet++ network architecture]

Hierarchical feature extraction: set abstraction

This module mainly consists of 3 parts:

1. Sampling layer: relatively important points are extracted from the dense point cloud to serve as center points, using FPS (farthest point sampling). This addresses the first problem above.

2. Grouping layer: find the K nearest points around each center point to form a local point region. This is somewhat like image convolution: it carves out a local neighborhood from which features can be extracted. This addresses the second problem.

3. Feature extraction layer (PointNet layer): a small PointNet extracts features from each local point region.

The FPS process is as follows:

(1) Randomly select a point as the initial sampling point;

(2) For each point not yet selected, compute its distance to the already selected set, and add the point with the largest distance to the selected set;

(3) Update the distances using the newly added point, and iterate until the target number of sampling points is reached.

As shown in the figure below, 5 points are selected from all the points.

[Figure: FPS selecting 5 points from a point cloud]

The code is implemented as follows:

import torch


def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)         # [B, npoint]  indices of the npoint points to select
    distance = torch.ones(B, N).to(device) * 1e10                           # [B, N]       running minimum distance from each of the N original points to the selected set
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)       # [B]          one random index per batch as the initial point
    batch_indices = torch.arange(B, dtype=torch.long).to(device)            # [B]          0..B-1
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)           # coordinates of the current farthest point
        dist = torch.sum((xyz - centroid) ** 2, -1)                        # squared distance from every point to this point
        mask = dist < distance                                             # points that are now closer to the selected set
        distance[mask] = dist[mask]                                        # update the running minimum distances
        farthest = torch.max(distance, -1)[1]                              # index of the point farthest from the selected set, [B]
    return centroids                                                       # indices of the selected points, [B, npoint]
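
As a quick sanity check, here is a minimal usage sketch (the batch size, point count, and sample count are illustrative):

xyz = torch.rand(2, 1024, 3)            # [B, N, 3] random point cloud
idx = farthest_point_sample(xyz, 512)   # [B, 512] indices of the chosen centroids
print(idx.shape)                        # torch.Size([2, 512])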

The grouping layer flow is as follows:

(1) Obtain the center points selected by FPS;

(2) Around each center point, select the required number of nearby points from the original set according to distance, forming a new point set centered on each FPS point;

(3) Each new point set undergoes an operation similar to coordinate normalization (each point's coordinates are expressed relative to its center), producing 3 new features that are then combined with each point's original features before feature extraction.

The simplified diagrams are as follows:

[Figure 1: the initial points, with the FPS-selected center points in red]

[Figure 2: the points grouped around each center point, shown in green]

In Figure 1, the red dots are the center points produced by FPS, and the black dots are the initial points. In Figure 2, the green points are those selected by distance; together with the red points they form a series of local point sets.

In the grouping layer, the authors propose three schemes: SSG (single-scale grouping), MSG (multi-scale grouping), and MRG (multi-resolution grouping). In essence, grouping is performed multiple times with different radii or at different resolutions. This addresses the third problem above (non-uniform density).

SSG: equivalent to grouping with only a single radius.


MSG: equivalent to grouping with multiple radii at the same resolution, then combining the resulting point sets (their features are concatenated after extraction).

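To make MSG concrete, here is a minimal sketch of the multi-scale grouping idea, built on the query_ball_point routine and the index_points helper reproduced later in this post. The name msg_group is invented here for illustration; the popular PyTorch port folds this into a class called PointNetSetAbstractionMsg:

import torch

def msg_group(xyz, new_xyz, radius_list, nsample_list, points=None):
    # For each radius, group nsample neighbours around every centroid and
    # normalize the coordinates. The SA layer then runs a separate PointNet
    # on each scale and concatenates the resulting feature vectors.
    B, S, C = new_xyz.shape
    grouped_list = []
    for radius, nsample in zip(radius_list, nsample_list):
        idx = query_ball_point(radius, nsample, xyz, new_xyz)    # [B, S, nsample]
        grouped_xyz = index_points(xyz, idx)                     # [B, S, nsample, 3]
        grouped_xyz = grouped_xyz - new_xyz.view(B, S, 1, C)     # center on each centroid
        if points is not None:
            grouped_points = index_points(points, idx)           # [B, S, nsample, D]
            grouped_xyz = torch.cat([grouped_xyz, grouped_points], dim=-1)
        grouped_list.append(grouped_xyz)
    return grouped_list   # one [B, S, nsample_i, 3(+D)] tensor per radius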

MRG: equivalent to grouping at different resolutions and combining the features from the different levels.


The code is implemented as follows:

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    # 1. Preset the search radius R and the number of points K per sub-region.
    # 2. The S points extracted above act as S centroids. A ball of radius R around each
    #    centroid is the "query ball", i.e. the search region.
    # 3. Within each ball, find the points closest to the centroid (sorted by distance,
    #    keeping K of them). If the query ball holds more than nsample points, take the
    #    first nsample as the sub-region; if it holds fewer, resample a point (here the
    #    first point is simply duplicated) until nsample points are reached.
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])    # candidate indices 0..N-1 for every query point, e.g. [2, 512, 1024]
    sqrdists = square_distance(new_xyz, xyz)    # squared distances between query points and all points, [B, S, N]
    group_idx[sqrdists > radius ** 2] = N       # mark points outside the radius with the invalid index N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # valid indices are 0..N-1, invalid ones are N; sorting pushes invalid entries to the end, keep the first nsample  [B, S, nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])  # the first valid index (the point closest to the centroid), repeated nsample times  [B, S, nsample]
    mask = group_idx == N                 # check whether any of the first nsample entries are invalid
    group_idx[mask] = group_first[mask]   # replace invalid entries with the first valid point
    return group_idx
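
Both query_ball_point above and sample_and_group below call two helpers that this post does not define, square_distance and index_points. For completeness, this is how they look in the widely used PyTorch reimplementation of PointNet++ that this code follows (reproduced here so the snippets are self-contained):

def square_distance(src, dst):
    """
    Squared Euclidean distance between every pair of points.
    src: [B, N, C], dst: [B, M, C]  ->  [B, N, M]
    Uses ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2 to avoid an explicit loop.
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Gather points by index along the point dimension.
    points: [B, N, C], idx: [B, S] or [B, S, nsample]
    ->      [B, S, C] or [B, S, nsample, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    return points[batch_indices, idx, :]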


def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: number of centroids to sample
        radius: local region radius
        nsample: max sample number in local region
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]  FPS returns the indices of the selected points
    new_xyz = index_points(xyz, fps_idx)          # [B, npoint, C]  gather the points at those indices
    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # [B, npoint, nsample]  like a convolution window: find nsample points of the original xyz near each new_xyz point and return their indices
    grouped_xyz = index_points(xyz, idx)          # [B, npoint, nsample, C]  gather the points at those indices
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)   # normalize: subtract each region's center point, [B, npoint, nsample, C]

    if points is not None:  # points holds any pre-existing per-point features
        grouped_points = index_points(points, idx)  # the original features of each region, [B, npoint, nsample, D]
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]  combine normalized coordinates with original features
    else:
        new_points = grouped_xyz_norm               # with no original features, the data is just the normalized points
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
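
A minimal end-to-end sketch of sampling plus grouping (shapes are illustrative, and no pre-existing features are passed):

xyz = torch.rand(2, 1024, 3)      # raw coordinates, [B, N, 3]
new_xyz, new_points = sample_and_group(npoint=512, radius=0.2,
                                       nsample=32, xyz=xyz, points=None)
print(new_xyz.shape)     # torch.Size([2, 512, 3])      centroids
print(new_points.shape)  # torch.Size([2, 512, 32, 3])  normalized local regions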

Feature extraction layer

This layer consists of basic convolution and pooling operations (a shared MLP followed by max pooling). After it, each point selected by FPS carries a richer feature vector.

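As a concrete illustration, here is a minimal sketch of this per-region PointNet, following the shared-MLP-plus-max-pooling pattern; the name LocalPointNet and the layer widths are illustrative choices (the reference implementation folds this into its PointNetSetAbstraction class):

import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalPointNet(nn.Module):
    """Shared MLP over each local region, then max-pool over the nsample axis."""
    def __init__(self, in_channel, mlp=(64, 64, 128)):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        last = in_channel
        for out in mlp:
            self.convs.append(nn.Conv2d(last, out, 1))   # 1x1 conv = per-point MLP
            self.bns.append(nn.BatchNorm2d(out))
            last = out

    def forward(self, new_points):
        # new_points: [B, npoint, nsample, C+D] from sample_and_group
        x = new_points.permute(0, 3, 2, 1)               # [B, C+D, nsample, npoint]
        for conv, bn in zip(self.convs, self.bns):
            x = F.relu(bn(conv(x)))
        x = torch.max(x, 2)[0]                           # max over nsample -> [B, mlp[-1], npoint]
        return x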

The set abstraction layer is repeated several times, and finally a few fully connected layers classify the point cloud.
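
For completeness, a sketch of such a classification head, assuming the common 1024 -> 512 -> 256 -> num_classes layout used by the popular PyTorch port (treat the exact sizes and dropout rate as assumptions):

import torch.nn as nn
import torch.nn.functional as F

class ClsHead(nn.Module):
    """Fully connected head applied to the global 1024-d feature vector."""
    def __init__(self, num_classes=40):
        super().__init__()
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, num_classes)
        self.drop = nn.Dropout(0.4)

    def forward(self, x):                    # x: [B, 1024] global feature
        x = self.drop(F.relu(self.bn1(self.fc1(x))))
        x = self.drop(F.relu(self.bn2(self.fc2(x))))
        return F.log_softmax(self.fc3(x), dim=-1)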

Segmentation task

[Figure: PointNet++ segmentation network architecture]

The feature extractor of the segmentation task is the same as that of the classification task, so here we focus on the upsampling stage. The authors propose a hierarchical feature propagation (FP) strategy based on distance-weighted interpolation.

The general process is as follows:

(1) Compute the weights for an inverse distance weighted average, as shown in the figure below. The red points are those that survived FPS and feature extraction, so they carry more features than the black points. Upsampling gives the black points matching features as well.

[Figure: sparse red points (after FPS and feature extraction) among the dense black points to be interpolated]

Weight calculation: for each black point, find the 3 nearest red points, then blend their features according to inverse distance weights. Each black point thereby obtains new features. The red points go through the same procedure, so every point in the figure above ends up with new features.

In the paper, this interpolation is an inverse distance weighted average over the k = 3 nearest neighbors, with power p = 2 (matching the squared distances used in the code below):

f(x) = Σᵢ wᵢ(x) · f(xᵢ) / Σⱼ wⱼ(x),   where wᵢ(x) = 1 / d(x, xᵢ)²,   i = 1, 2, 3
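
As a tiny worked example of the weighting (numbers invented for illustration): if a black point's three nearest red points lie at squared distances 1, 4, and 4, the raw weights are 1, 0.25, and 0.25, which normalize to 2/3, 1/6, and 1/6, so the closest red point contributes two thirds of the interpolated feature.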

(2) The newly generated features are concatenated (cat) with the features of the corresponding earlier layer, and convolutions then fuse them. This completes one upsampling step.

The code is implemented as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F


class PointNetFeaturePropagation(nn.Module):  # upsampling
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]          (previous, denser layer)
            xyz2: sampled input points position data, [B, C, S]  (current, sparser layer)
            points1: input points data, [B, D1, N]               (previous layer)
            points2: input points data, [B, D2, S]               (current layer)
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)    # pairwise distances between the two point sets, [B, N, S]
            dists, idx = dists.sort(dim=-1)        # sort by distance
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]  keep the 3 nearest points

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm              # normalized inverse distance weights
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)   # [B, N, D2]

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)   # skip connection: cat with the previous layer's features, [B, N, D1+D2]
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)                             # [B, D1+D2, N]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points  # [B, D', N]
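
A minimal usage sketch (the shapes and channel sizes are illustrative, and the square_distance / index_points helpers from earlier are assumed to be in scope): propagate features from 128 sampled points back to the 512 points of the previous layer.

fp = PointNetFeaturePropagation(in_channel=256 + 128, mlp=[256, 128])
xyz1 = torch.rand(2, 3, 512)        # coordinates of the previous (denser) layer
xyz2 = torch.rand(2, 3, 128)        # coordinates of the current (sparser) layer
points1 = torch.rand(2, 128, 512)   # previous layer features, D1 = 128
points2 = torch.rand(2, 256, 128)   # current layer features,  D2 = 256
out = fp(xyz1, xyz2, points1, points2)
print(out.shape)                    # torch.Size([2, 128, 512]) -> [B, D', N]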

This concludes the analysis of PointNet++! If anything here is misinterpreted, criticism and corrections are welcome; let's make progress together!


Origin blog.csdn.net/weixin_41202834/article/details/130313302