PointNet++ paper and code detailed analysis

This article is also based on the Zhihu article by Liu Xinchen:

https://zhuanlan.zhihu.com/p/266324173

1- motivation

PointNet++ is an improvement on PointNet.
To understand PointNet++, you first need to understand how PointNet works. For an introduction to PointNet, see my article:
PointNet paper and code detailed analysis

Because PointNet only uses MLP and max pooling, it has no ability to capture local structures, so it has limited capabilities in detail processing and generalization to complex scenes.

Several problems of PointNet, as I summarize them:
1. Point-wise MLP: each point is represented independently, so the ability to integrate local structure information is too weak --> PointNet++ improvement: sampling and grouping to aggregate local neighborhoods.

2. The global feature is obtained directly by max pooling, which causes a huge loss of information for both classification and segmentation tasks --> PointNet++ improvement: a hierarchical feature learning framework that downsamples step by step through multiple set abstraction layers, obtaining local-global features at different scales and levels.

3. For the segmentation task, the global feature is simply copied and concatenated with the local features, so the ability to generate discriminative features is limited --> PointNet++ improvement: an encoder-decoder structure designed for segmentation, downsampling and then upsampling, with skip connections that concatenate the local-global features of the corresponding layers.

2- solution

[Figure: hierarchical point set feature learning]

Hierarchical point set feature learning is the core of PointNet++.

PointNet++'s network is roughly an encoder-decoder structure

The encoder is a down-sampling process. It implements multi-level down-sampling through multiple set abstraction structures to obtain point-wise features of different scales. The last set abstraction output can be considered as a global feature. Among them, set abstraction consists of three modules: sampling, grouping, and pointnet.

The decoder differs between classification and segmentation. The classification decoder is relatively simple and will not be elaborated. The segmentation decoder is an upsampling process: while upsampling is realized through reverse interpolation and skip connections, local+global point-wise features are also obtained, so that the final representation is discriminative.

So before reading on, we'd better keep two questions in mind:
1. How is the downsampling process of PointNet++ implemented? How does PointNet++ represent global features? (Pay attention to set abstraction: the sampling layer, grouping layer, and pointnet layer.)
2. How is the upsampling process of PointNet++ for segmentation tasks implemented? How does PointNet++ represent point-wise features for segmentation? (Focus on reverse interpolation and skip connections.)

Below I will use the code to explain in detail how the PointNet++ network propagates forward (that is, what the network is actually doing), which is very important for understanding its design.
Notation: d denotes the coordinate space dimension, C denotes the feature space dimension.

2-1 encoder

On the basis of PointNet, a hierarchical feature learning framework is added. This multi-level structure is built from set abstraction layers.
At each set abstraction level, the point set is processed and abstracted to produce a smaller point set; this can be understood as a downsampling + characterization process (see the left half of the figure above).
A set abstraction layer consists of three parts (code pasted below):

def pointnet_sa_module(xyz, points, npoint, radius, nsample, mlp, mlp2, group_all, is_training, bn_decay, scope, bn=True, pooling='max', knn=False, use_xyz=True, use_nchw=False):
    ''' PointNet Set Abstraction (SA) Module
        Input:
            xyz: (batch_size, ndataset, 3) TF tensor
            points: (batch_size, ndataset, channel) TF tensor
            npoint: int32 -- #points sampled in farthest point sampling
            radius: float32 -- search radius in local region
            nsample: int32 -- how many points in each local region
            mlp: list of int32 -- output size for MLP on each point
            mlp2: list of int32 -- output size for MLP on each region
            group_all: bool -- group all points into one PC if set true, OVERRIDE
                npoint, radius and nsample settings
            use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
            use_nchw: bool, if True, use NCHW data format for conv2d, which is usually faster than NHWC format
        Return:
            new_xyz: (batch_size, npoint, 3) TF tensor
            new_points: (batch_size, npoint, mlp[-1] or mlp2[-1]) TF tensor
            idx: (batch_size, npoint, nsample) int32 -- indices for local regions
    '''
    data_format = 'NCHW' if use_nchw else 'NHWC'
    with tf.variable_scope(scope) as sc:
        # Sample and Grouping
        if group_all:
            nsample = xyz.get_shape()[1].value
            new_xyz, new_points, idx, grouped_xyz = sample_and_group_all(xyz, points, use_xyz)
        else:
            new_xyz, new_points, idx, grouped_xyz = sample_and_group(npoint, radius, nsample, xyz, points, knn, use_xyz)
        # Point Feature Embedding
        if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
        for i, num_out_channel in enumerate(mlp):
            new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                        padding='VALID', stride=[1,1],
                                        bn=bn, is_training=is_training,
                                        scope='conv%d'%(i), bn_decay=bn_decay,
                                        data_format=data_format) 
        if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])
        # Pooling in Local Regions
        if pooling=='max':
            new_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
        elif pooling=='avg':
            new_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
        elif pooling=='weighted_avg':
            with tf.variable_scope('weighted_avg'):
                dists = tf.norm(grouped_xyz,axis=-1,ord=2,keep_dims=True)
                exp_dists = tf.exp(-dists * 5)
                weights = exp_dists/tf.reduce_sum(exp_dists,axis=2,keep_dims=True) # (batch_size, npoint, nsample, 1)
                new_points *= weights # (batch_size, npoint, nsample, mlp[-1])
                new_points = tf.reduce_sum(new_points, axis=2, keep_dims=True)
        elif pooling=='max_and_avg':
            max_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
            avg_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
            new_points = tf.concat([avg_points, max_points], axis=-1)
        # [Optional] Further Processing 
        if mlp2 is not None:
            if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
            for i, num_out_channel in enumerate(mlp2):
                new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                            padding='VALID', stride=[1,1],
                                            bn=bn, is_training=is_training,
                                            scope='conv_post_%d'%(i), bn_decay=bn_decay,
                                            data_format=data_format) 
            if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])
        new_points = tf.squeeze(new_points, [2]) # (batch_size, npoints, mlp2[-1])
        return new_xyz, new_points, idx
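
Before breaking the module into its three parts, it helps to see how these SA layers are stacked. The sketch below follows the single-scale classification encoder in the official repo (pointnet2_cls_ssg); the exact npoint/radius/nsample/mlp values are reproduced from memory and may differ slightly from the released code:

# Encoder sketch: three stacked set abstraction levels (hyperparameters from memory).
# l0_xyz: (B, N, 3) input coordinates, l0_points: (B, N, C) input features or None.
l1_xyz, l1_points, l1_indices = pointnet_sa_module(l0_xyz, l0_points, npoint=512, radius=0.2, nsample=32,
                                                   mlp=[64, 64, 128], mlp2=None, group_all=False,
                                                   is_training=is_training, bn_decay=bn_decay, scope='layer1')
l2_xyz, l2_points, l2_indices = pointnet_sa_module(l1_xyz, l1_points, npoint=128, radius=0.4, nsample=64,
                                                   mlp=[128, 128, 256], mlp2=None, group_all=False,
                                                   is_training=is_training, bn_decay=bn_decay, scope='layer2')
# The last level groups all remaining points into one set -> a single global feature vector.
l3_xyz, l3_points, l3_indices = pointnet_sa_module(l2_xyz, l2_points, npoint=None, radius=None, nsample=None,
                                                   mlp=[256, 512, 1024], mlp2=None, group_all=True,
                                                   is_training=is_training, bn_decay=bn_decay, scope='layer3')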

2-1-1 sampling layer

Downsample the point set using FPS (Farthest Point Sampling), reducing the input point set from size N to a smaller size N1. FPS can be understood as making the sampled points as far away from each other as possible; the advantage of this sampling is that the downsampling result is more uniform.

FPS (Farthest Point Sampling) algorithm flow:

Assume there are N points, P = {P1, P2, ..., Pn}; the sampled set is S, initially S = {}; we want to sample c points.
1. Randomly select one point Pk1 from the N points and put it into S, so S = {Pk1}.
2. Compute the distance from each of the remaining n-1 points to the set S (n-1 distances in total), pick the point Pk2 farthest from S, and put it into S, so S = {Pk1, Pk2}.
3. Compute the distance from each of the remaining n-2 points to S. For each of these points there are now two points in S, so we get two distances and take the smaller one as that point's distance to the set; this gives n-2 point-to-set distances. Pick the farthest point Pk3 and put it into S, so S = {Pk1, Pk2, Pk3}.
...
Repeat until S = {Pk1, Pk2, ..., Pkc}.

Advantages of FPS:
The advantage of farthest point sampling is that it can cover all points in space as much as possible.
Disadvantages:
High computational complexity and time-consuming

import numpy as np

class FarthestSampler:
    def __init__(self):
        pass
    def _calc_distances(self, p0, points):
        # squared Euclidean distance from p0 to every point
        return ((p0 - points) ** 2).sum(axis=1)
    def __call__(self, pts, k):
        farthest_pts = np.zeros((k, 3), dtype=np.float32)
        # start from a random point
        farthest_pts[0] = pts[np.random.randint(len(pts))]
        distances = self._calc_distances(farthest_pts[0], pts)
        for i in range(1, k):
            # pick the point farthest from the already-sampled set
            farthest_pts[i] = pts[np.argmax(distances)]
            # each point's distance to the set is the min over the sampled points
            distances = np.minimum(
                distances, self._calc_distances(farthest_pts[i], pts))
        return farthest_pts
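
A quick usage check of this NumPy reference implementation on synthetic data (not part of the original post):

import numpy as np

np.random.seed(0)
pts = np.random.rand(1024, 3).astype(np.float32)  # 1024 random 3D points
sampler = FarthestSampler()
sampled = sampler(pts, 128)                       # keep 128 farthest-spread points
print(sampled.shape)                              # (128, 3)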

The input size is B * N * (d+C), where B is the batch size, N is the number of points in the point set, d is the coordinate dimension of a point, and C is the dimension of a point's other features (such as normals). Generally d=3, C=0.

The output size is B * N1 * (d+C), N1<N, because this is a downsampling process.

The specific implementation of sampling and grouping is written in a function:

def sample_and_group(npoint, radius, nsample, xyz, points, knn=False, use_xyz=True):
    '''
    Input:
        npoint: int32
        radius: float32
        nsample: int32
        xyz: (batch_size, ndataset, 3) TF tensor
        points: (batch_size, ndataset, channel) TF tensor, if None will just use xyz as points
        knn: bool, if True use kNN instead of radius search
        use_xyz: bool, if True concat XYZ with local point features, otherwise just use point features
    Output:
        new_xyz: (batch_size, npoint, 3) TF tensor
        new_points: (batch_size, npoint, nsample, 3+channel) TF tensor
        idx: (batch_size, npoint, nsample) TF tensor, indices of local points as in ndataset points
        grouped_xyz: (batch_size, npoint, nsample, 3) TF tensor, normalized point XYZs
            (subtracted by seed point XYZ) in local regions
    '''
    new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz)) # (batch_size, npoint, 3)
    if knn:
        _,idx = knn_point(nsample, xyz, new_xyz)
    else:
        idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = group_point(xyz, idx) # (batch_size, npoint, nsample, 3)
    grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1]) # translation normalization
    if points is not None:
        grouped_points = group_point(points, idx) # (batch_size, npoint, nsample, channel)
        if use_xyz:
            new_points = tf.concat([grouped_xyz, grouped_points], axis=-1) # (batch_size, npoint, nample, 3+channel)
        else:
            new_points = grouped_points
    else:
        new_points = grouped_xyz
    return new_xyz, new_points, idx, grouped_xyz

The part corresponding to sampling is:

new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz)) # (batch_size, npoint, 3)

xyz is the B * N * 3 point cloud, and npoint is the number of points after downsampling. Note: the FPS in PointNet++ is done in the coordinate space, not in the feature space. This matters because FPS itself is not differentiable and cannot backpropagate gradients.

In the spirit of digging into the root of the problem, let's take a look at what farthest_point_sample and gather_point are doing

The input and output of farthest_point_sample are very clear: the output contains the indices of the downsampled points in inp (the input points), so it is a B * npoint tensor of type int32.

def farthest_point_sample(npoint,inp):
    '''
input:
    int32
    batch_size * ndataset * 3   float32
returns:
    batch_size * npoint         int32
    '''
    return sampling_module.farthest_point_sample(inp, npoint)

The function of gather_point is to convert the index output above into a real point cloud

def gather_point(inp,idx):
    '''
input:
    batch_size * ndataset * 3   float32
    batch_size * npoints        int32
returns:
    batch_size * npoints * 3    float32
    '''
    return sampling_module.gather_point(inp,idx)
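
Since gather_point is a compiled custom op, here is a small NumPy illustration of what it does (a hypothetical equivalent for intuition only): a batched index lookup of the sampled points.

import numpy as np

B, N, npoint = 2, 1024, 256
inp = np.random.rand(B, N, 3).astype(np.float32)   # batch of point clouds
idx = np.random.randint(0, N, size=(B, npoint))    # indices as returned by FPS
gathered = inp[np.arange(B)[:, None], idx]         # (B, npoint, 3), same role as gather_point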

2-1-2 grouping layer

The sampling in the previous step reduces N * (d+C) to N1 * (d+C) (for convenience, consider a single point cloud and ignore the batch dimension), which can be understood as selecting N1 key points out of the N points.

The purpose of grouping is then to take each key point as a center, find a fixed-size set of its neighbors (let the size be K), and form a local neighborhood (patch) with them. In other words, N1 local neighborhoods are generated, and the output size is N1 * K * (d+C).

if knn:
    _,idx = knn_point(nsample, xyz, new_xyz)
else:
    idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
grouped_xyz = group_point(xyz, idx) # (batch_size, npoint, nsample, 3)

Two points to note:

1. Finding neighbors is also done in the coordinate space (that is, the inputs and outputs of the code above only involve d, without C; the features C are concatenated in the code further below), not in the feature space.
2. There are two ways to find the neighborhood: kNN and ball query (query ball point).

kNN (k nearest neighbors) finds the K points closest in coordinate space.
Ball query delineates a ball of a certain radius and takes the points inside that radius ball as the neighbors.

For ball query: how does query ball point ensure that every local neighborhood has the same number of sampled points?
If the number of points inside the query ball is greater than the size K, simply take the first K as the local neighborhood; if it is smaller, resample an existing point (copy the point closest to the center) until the neighborhood reaches size K. A NumPy sketch of this padding behavior is given below.
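
A minimal NumPy sketch of this padding behavior for a single (unbatched) point cloud; it illustrates the idea described above and is not the actual CUDA op query_ball_point:

import numpy as np

def ball_query_np(radius, nsample, xyz, new_xyz):
    # xyz: (N, 3) all points, new_xyz: (S, 3) query centers; returns idx: (S, nsample)
    idx = np.zeros((new_xyz.shape[0], nsample), dtype=np.int64)
    for i, center in enumerate(new_xyz):
        d = np.linalg.norm(xyz - center, axis=1)
        inside = np.flatnonzero(d <= radius)
        if len(inside) >= nsample:
            idx[i] = inside[:nsample]                  # more than K points: take the first K
        else:
            # fewer than K points: pad by copying the point closest to the center
            closest = inside[np.argmin(d[inside])] if len(inside) > 0 else np.argmin(d)
            pad = np.full(nsample - len(inside), closest, dtype=np.int64)
            idx[i] = np.concatenate([inside, pad])
    return idx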

The difference between kNN and ball query (from the paper): "Compared with kNN, ball query's local neighborhood guarantees a fixed region scale thus making local region feature more generalizable across space, which is preferred for tasks requiring local pattern recognition (e.g. semantic point labeling)." In other words, ball query is better suited for local/detail recognition, such as part segmentation.

There are also experiments in the supplementary material comparing kNN and ball query:
[Figure: kNN vs. ball query comparison from the supplementary material]
The rest of the sample_and_group code:

Both the sample and group operations are performed in the coordinate space, so if feature space information also exists (i.e., point-wise features), it is concatenated with the coordinate space here to form new point-wise features, which are then fed into the following pointnet unit for feature learning.

if points is not None:
    grouped_points = group_point(points, idx) # (batch_size, npoint, nsample, channel)
    if use_xyz:
        new_points = tf.concat([grouped_xyz, grouped_points], axis=-1) # (batch_size, npoint, nample, 3+channel)
    else:
        new_points = grouped_points
else:
    new_points = grouped_xyz

2-1-3 PointNet layer

Use PointNet to characterize the results above.
Input: B * N * K * (d+C)
Output: B * N * (d+C1)

The following code is mainly divided into three parts:
1. point feature embedding
2. pooling in local regions
3. further processing

For the first part, point feature embedding:
The input here is B * N * K * (d+C). It can be compared to an image with batch size B, width and height N and K, and d+C channels. The MLP is then just a 1x1 convolution that does not change the spatial size of the feature map but only changes (increases) the number of channels.
The output of this "embedding" part is B * N * K * C1.

For the second part, pooling in local regions:
Pooling is done over each local neighborhood; the output of this part is B * N * 1 * C1.

For the third part, further processing:
An MLP is applied to the pooled result, again just 1x1 convolutions. In the actual PointNet++ experiments this part is not used (mlp2 is None).

# Point Feature Embedding
if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
for i, num_out_channel in enumerate(mlp):
    new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                padding='VALID', stride=[1,1],
                                bn=bn, is_training=is_training,
                                scope='conv%d'%(i), bn_decay=bn_decay,
                                data_format=data_format) 
if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

# Pooling in Local Regions
if pooling=='max':
    new_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
elif pooling=='avg':
    new_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
elif pooling=='weighted_avg':
    with tf.variable_scope('weighted_avg'):
        dists = tf.norm(grouped_xyz,axis=-1,ord=2,keep_dims=True)
        exp_dists = tf.exp(-dists * 5)
        weights = exp_dists/tf.reduce_sum(exp_dists,axis=2,keep_dims=True) # (batch_size, npoint, nsample, 1)
        new_points *= weights # (batch_size, npoint, nsample, mlp[-1])
        new_points = tf.reduce_sum(new_points, axis=2, keep_dims=True)
elif pooling=='max_and_avg':
    max_points = tf.reduce_max(new_points, axis=[2], keep_dims=True, name='maxpool')
    avg_points = tf.reduce_mean(new_points, axis=[2], keep_dims=True, name='avgpool')
    new_points = tf.concat([avg_points, max_points], axis=-1)

# [Optional] Further Processing 
if mlp2 is not None:
    if use_nchw: new_points = tf.transpose(new_points, [0,3,1,2])
    for i, num_out_channel in enumerate(mlp2):
        new_points = tf_util.conv2d(new_points, num_out_channel, [1,1],
                                    padding='VALID', stride=[1,1],
                                    bn=bn, is_training=is_training,
                                    scope='conv_post_%d'%(i), bn_decay=bn_decay,
                                    data_format=data_format) 
    if use_nchw: new_points = tf.transpose(new_points, [0,2,3,1])

2-1-4 There is another question about the encoder

PointNet++ is essentially a representation of local neighborhoods.

This brings a challenge: non-uniform sampling density. Training on local neighborhoods of a sparse point cloud may not mine the local structure of the point cloud well.

PointNet++'s approach: learn to combine features from regions of different scales when the input sampling density changes.

Therefore, the article proposes two solutions:
1. Multi-scale grouping (MSG)
[Figure: multi-scale grouping (MSG)]
For each center point of the current layer, take query balls of different radii, which yields multiple concentric local neighborhoods with the same center but different scales. Each of these local neighborhoods is represented separately, and all the representations are concatenated together, as shown in the figure.

At the code level, this simply adds a loop over radius_list, processes each radius separately, and finally concatenates the results:

new_xyz = gather_point(xyz, farthest_point_sample(npoint, xyz))
new_points_list = []
for i in range(len(radius_list)):
    radius = radius_list[i]
    nsample = nsample_list[i]
    idx, pts_cnt = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = group_point(xyz, idx)
    grouped_xyz -= tf.tile(tf.expand_dims(new_xyz, 2), [1,1,nsample,1])
    if points is not None:
        grouped_points = group_point(points, idx)
        if use_xyz:
            grouped_points = tf.concat([grouped_points, grouped_xyz], axis=-1)
    else:
        grouped_points = grouped_xyz
    if use_nchw: grouped_points = tf.transpose(grouped_points, [0,3,1,2])
    for j,num_out_channel in enumerate(mlp_list[i]):
        grouped_points = tf_util.conv2d(grouped_points, num_out_channel, [1,1],
                                        padding='VALID', stride=[1,1], bn=bn, is_training=is_training,
                                        scope='conv%d_%d'%(i,j), bn_decay=bn_decay)
    if use_nchw: grouped_points = tf.transpose(grouped_points, [0,2,3,1])
    new_points = tf.reduce_max(grouped_points, axis=[2])
    new_points_list.append(new_points)
new_points_concat = tf.concat(new_points_list, axis=-1)

2. Multi-resolution grouping (MRG)
[Figure: multi-resolution grouping (MRG)]

The local neighborhood representation at the current set abstraction level consists of two parts:

Left branch: aggregate the features of each local neighborhood (center point) from the previous set abstraction level (remember that the previous level has more points?).

Right branch: use a single PointNet to directly process the raw point cloud of the local neighborhood.

2-2 decoder

2-2-1 Decoder for classification tasks

[Figure: classification decoder]
It is relatively simple: the global feature obtained by the encoder's downsampling is fed into several fully connected layers and finally classified by a softmax.
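
For reference, a sketch of this classification head in the style of the repo (layer sizes reproduced from memory of pointnet2_cls_ssg and may differ; tf_util.fully_connected and tf_util.dropout are the repo's helpers):

# net: (batch_size, 1024) global feature from the last set abstraction, after reshape
net = tf_util.fully_connected(net, 512, bn=True, is_training=is_training, scope='fc1', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp1')
net = tf_util.fully_connected(net, 256, bn=True, is_training=is_training, scope='fc2', bn_decay=bn_decay)
net = tf_util.dropout(net, keep_prob=0.5, is_training=is_training, scope='dp2')
net = tf_util.fully_connected(net, 40, activation_fn=None, scope='fc3')  # 40 classes for ModelNet40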

2-2-2 Decoder for segmentation tasks

[Figure: segmentation decoder]
After the encoder in the first half, what we obtain is a global feature, or the representation of a very small number of points (effectively a global feature).
If we want to do segmentation, what we need are point-wise features. What can we do?

PointNet's approach is very simple: directly copy the global feature and concatenate it with the earlier local features, so that the new point-wise features gain some degree of "neighborhood" information. This simple and crude method obviously cannot produce a very discriminative representation.

Don't worry, PointNet++ is here.

PointNet++ designs a reverse interpolation (inverse distance weighted interpolation) method to realize the upsampling decoder; through this interpolation and skip connections it obtains discriminative point-wise features:
[Figure: feature propagation via interpolation and skip connection]
Let the red rectangle be point set P1: N1 * C and the blue rectangle be point set P2: N2 * C2; since the decoder is an upsampling process, N2 > N1.

1. The specific method of reverse interpolation: for each point x in the blue point set P2, find the k points x1, x2, ..., xk closest to it in P1, in the original coordinate space of the point cloud.
The interpolation uses the features of fewer points to produce features for more points, thereby achieving upsampling.
At this point we know the features of x1, x2, ..., xk, and we want the feature of x: it is a weighted sum of the features of x1, x2, ..., xk, where the weight is inversely related to the distance between x and xi, i.e., the farther away a neighbor is, the smaller its contribution to x's feature. The interpolation formula is given below.
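
The interpolation formula from the PointNet++ paper (inverse distance weighted average, with defaults p = 2, k = 3) can be written as:

f^{(j)}(x) = \frac{\sum_{i=1}^{k} w_i(x)\, f_i^{(j)}}{\sum_{i=1}^{k} w_i(x)}, \qquad w_i(x) = \frac{1}{d(x, x_i)^p}, \quad j = 1, \dots, C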

The other points in P2 are handled analogously, so the features are upsampled back onto the denser point set.

2. The specific method of skip connection:
The point-wise features obtained from the interpolation come from the upper (deeper) layer of the decoder, so they can be regarded as global-level information, which is not discriminative enough by itself, because we still lack local-level information!

As shown in the figure above, reverse interpolation only gives us the C2 part; we also need the C1 features that provide the local-level information.

This is where the skip connection comes in.
[Figure: skip connection concatenating encoder features]
The skip connection simply concatenates the representation from the corresponding layer of the encoder: because the C1 representation of the encoder's blue rectangle point set was computed from a larger green rectangle point set, it already carries local-level information to a certain extent.

Step by step in the decoder, reverse interpolation and skip connections give us local + global point-wise features, i.e., discriminative features that can be applied to the segmentation task.
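
In code, both steps live in one feature propagation (FP) module. The sketch below closely follows the official repo's pointnet_fp_module (reproduced from memory, so details may differ; three_nn and three_interpolate are the repo's custom TF ops for the k=3 neighbor search and the weighted interpolation):

def pointnet_fp_module(xyz1, xyz2, points1, points2, mlp, is_training, bn_decay, scope, bn=True):
    ''' PointNet Feature Propagation (FP) Module (sketch)
        Input:
            xyz1: (batch_size, ndataset1, 3) -- the denser level (to interpolate onto)
            xyz2: (batch_size, ndataset2, 3) -- the sparser level, ndataset2 < ndataset1
            points1: (batch_size, ndataset1, nchannel1) -- skip-connected encoder features (C1)
            points2: (batch_size, ndataset2, nchannel2) -- decoder features to upsample (C2)
            mlp: list of int32 -- output sizes of the unit pointnet (1x1 convs)
        Return:
            new_points1: (batch_size, ndataset1, mlp[-1])
    '''
    with tf.variable_scope(scope) as sc:
        # inverse-distance-weighted interpolation from the 3 nearest sparse points
        dist, idx = three_nn(xyz1, xyz2)
        dist = tf.maximum(dist, 1e-10)
        norm = tf.reduce_sum((1.0 / dist), axis=2, keep_dims=True)
        norm = tf.tile(norm, [1, 1, 3])
        weight = (1.0 / dist) / norm
        interpolated_points = three_interpolate(points2, idx, weight)
        # skip connection: concatenate the corresponding encoder features
        if points1 is not None:
            new_points1 = tf.concat(axis=2, values=[interpolated_points, points1])
        else:
            new_points1 = interpolated_points
        # unit pointnet (shared MLP / 1x1 convs) on the concatenated features
        new_points1 = tf.expand_dims(new_points1, 2)
        for i, num_out_channel in enumerate(mlp):
            new_points1 = tf_util.conv2d(new_points1, num_out_channel, [1, 1],
                                         padding='VALID', stride=[1, 1],
                                         bn=bn, is_training=is_training,
                                         scope='conv_%d' % (i), bn_decay=bn_decay)
        new_points1 = tf.squeeze(new_points1, [2])  # (batch_size, ndataset1, mlp[-1])
        return new_points1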

2-3 loss

Whether the application is classification or segmentation, it is essentially a classification problem, so the loss is the cross-entropy loss commonly used for classification tasks.
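
A minimal sketch of such a loss in TF1 style (assuming pred holds per-shape or per-point logits and label holds the integer class ids):

import tensorflow as tf

def get_loss(pred, label):
    # pred: (..., num_classes) raw logits, label: (...) int class ids
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=pred, labels=label)
    return tf.reduce_mean(loss)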

2-4 Other issues

Q: How does the gradient of PointNet++ backpropagate through FPS?

A: The FPS in PointNet++ does not actually participate in gradient computation or backpropagation.

It can be understood as if PointNet++ performed FPS downsampling of the point cloud at the different scales in advance, and then fed these data into the network for training.

3. dataset

The data set is the same as pointnet

4.experiments

There are two main experimental results of interest:

ModelNet40 classification results
ShapeNet Part segmentation results
[Table: ModelNet40 classification results]
[Table: ShapeNet part segmentation results]

5.conclusion

PointNet++ is the sequel to PointNet and makes up for some of PointNet's defects to a certain extent. The representation network is basically similar to PointNet: still MLPs, 1x1 convolutions, and pooling. The core innovations lie in the design of local neighborhood sampling and representation, and in the multi-level encoder-decoder network structure.

When I first saw the PointNet++ network structure, I felt the design was very elegant, especially the concrete implementation of downsampling and upsampling and its use for representing segmentation tasks. In fact, though, whether for classification or segmentation, the improvement over PointNet is only 1-2 points.

PointNet++, especially its encoder in the first half, provides a very good representation network. Many point cloud processing papers use PointNet++ as their feature extractor.

Origin blog.csdn.net/toCVer/article/details/125405057