Classify 3D point clouds with PointNet

In this tutorial, we will learn how to train PointNet for classification. We will focus mainly on the data and the training process; a tutorial showing how to code PointNet from scratch is located here. The code for this tutorial is in this GitHub repository, and the notebook we'll be using is in this GitHub repository. Some of the code was inspired by this GitHub repository.

1. Get data

We will use a smaller version of the ShapeNet dataset with only 16 classes. If you are using Colab, you can run the following code to get the data. Be warned that the download takes a long time.

!wget -nv https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate
!unzip shapenetcore_partanno_segmentation_benchmark_v0.zip
!rm shapenetcore_partanno_segmentation_benchmark_v0.zip

If you want to run locally, open the URL from the first line above in a browser and the data will download automatically as a zip file.

The dataset contains 16 folders with class identifiers (called "synsetoffset" in the readme). The folder structure is:

synsetoffset
  |- points                  # uniformly sampled points from ShapeNetCore models
  |- point_labels            # per-point segmentation labels
  |- seg_img                 # visualization of the labels
train_test_split:            # JSON files with the train/validation/test splits

The custom PyTorch dataset is located here; explaining its code is beyond the scope of this tutorial. The important thing to understand is that the dataset can return either (point_cloud, class) or (point_cloud, seg_labels). During training and validation, we add Gaussian noise to the point clouds and randomly rotate them around the vertical axis (the y-axis in this case). We also apply min-max normalization to the point clouds so that they range from 0 to 1. We can create an instance of the ShapeNet dataset like this:

from shapenet_dataset import ShapenetDataset

# __getitem__ returns (point_cloud, class)
train_dataset = ShapenetDataset(ROOT, npoints=2500, split='train', classification=True)
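
The three preprocessing steps described above (Gaussian noise, a random rotation about the y-axis, and min-max normalization) can be sketched in NumPy as follows. This is only an illustration and not the code inside ShapenetDataset: the function name `augment_point_cloud`, the noise level `noise_std=0.01`, and the choice to normalize each axis separately are assumptions made for the example.

```python
import numpy as np

def augment_point_cloud(points, noise_std=0.01, rng=None):
    """Jitter, rotate about the y-axis, then min-max normalize an (n, 3) array.

    Hypothetical helper for illustration; noise_std is an assumed value.
    """
    rng = np.random.default_rng() if rng is None else rng

    # small Gaussian jitter on every coordinate
    noisy = points + rng.normal(0.0, noise_std, size=points.shape)

    # random rotation about the vertical (y) axis
    theta = rng.uniform(0.0, 2.0 * np.pi)
    c, s = np.cos(theta), np.sin(theta)
    rot_y = np.array([[c, 0.0, s],
                      [0.0, 1.0, 0.0],
                      [-s, 0.0, c]])
    rotated = noisy @ rot_y.T

    # per-axis min-max normalization to [0, 1] (guarding against a flat axis)
    mins = rotated.min(axis=0)
    spans = np.maximum(rotated.max(axis=0) - mins, 1e-12)
    return (rotated - mins) / spans

rng = np.random.default_rng(0)
example_points = rng.normal(size=(100, 3))
augmented = augment_point_cloud(example_points, rng=rng)
```

Because the rotation matrix leaves the y coordinate untouched, the object stays upright while its heading changes on every call.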

2. Explore data

Before we start any training, let's explore some of the training data. For this we will use Open3D version 0.16.0 (the version must be 0.16.0 or higher).

!pip install open3d==0.16.0

We can now view a sample point cloud using the following code. You should notice that the point cloud is displayed in a different orientation each time you run the code.

import open3d as o3
from shapenet_dataset import ShapenetDataset

# use a separate variable so the classification train_dataset is not overwritten
sample_dataset = ShapenetDataset(ROOT, npoints=20000, split='train',
                                 classification=False, normalize=False)

points, seg = sample_dataset[4000]

pcd = o3.geometry.PointCloud()
pcd.points = o3.utility.Vector3dVector(points)
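# read_pointnet_colors is a helper from the tutorial's code (not part of Open3D)
# that maps segmentation labels to RGB colors
pcd.colors = o3.utility.Vector3dVector(read_pointnet_colors(seg.numpy()))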

o3.visualization.draw_plotly([pcd])

Figure 1. Randomly rotated noisy point cloud. Y axis is the vertical axis

You probably won't notice much of a difference from the noise, because we add only a small amount so as not to distort the structure too much; even so, this small amount is enough to have an effect on the model. Now let's take a look at the class frequencies in the training data.

Figure 2. Histogram of training classification data points

From Figure 2 we can see that this is definitely not a balanced training set. Therefore, we may want to apply class weights, or even use a focal loss, to help our model learn.
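
As a rough illustration of what class weights could look like, the sketch below computes inverse-frequency weights from a list of labels. This is one common heuristic and not the scheme used later in this tutorial, which down-weights a few frequent classes by hand; the function name `inverse_frequency_weights` and the toy labels are assumptions for the example.

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """Weight each class by total / (num_classes * class_count).

    Hypothetical helper; classes absent from the data get weight 0.0.
    """
    counts = Counter(labels)
    total = len(labels)
    weights = []
    for class_id in range(num_classes):
        n_c = counts.get(class_id, 0)
        weights.append(total / (num_classes * n_c) if n_c else 0.0)
    return weights

# toy label list: class 0 is frequent, class 2 is rare
toy_labels = [0] * 6 + [1] * 3 + [2] * 1
toy_weights = inverse_frequency_weights(toy_labels, num_classes=3)
```

Rare classes receive larger weights, so their errors contribute more to a weighted cross-entropy loss.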

3. PointNet loss function

When training PointNet for classification, we can use the standard cross-entropy loss in PyTorch, but we also want to add the regularization term mentioned in the paper.

The regularization term forces the feature transformation matrices to be orthogonal, but why? Feature transformation matrices aim to rotate (transform) high-dimensional representations of point clouds. How can we be sure that this learned high-dimensional rotation is actually rotating the point cloud? To answer this question, let's consider some desired rotation properties.

We want the learned rotation to be affine, meaning it preserves structure. We want to make sure it doesn't do anything strange, such as mapping the points into a lower-dimensional space or distorting their structure. We can't simply plot n×64 point clouds to check this, but we can push the model toward learning valid rotations by encouraging the transformation to be orthogonal. This works because an orthogonal matrix preserves both lengths and angles, and a rotation matrix is a special type of orthogonal matrix. We can "encourage" the model to learn an orthogonal rotation matrix by regularizing with:

$$ L_{reg} = \lVert I - A A^{T} \rVert_F^2 $$

Figure 3. PointNet regularization term

We exploit a fundamental property of orthogonal matrices: their rows and columns are orthonormal vectors, so the product of the matrix with its transpose is the identity. For a perfectly orthogonal matrix, the regularization term in Figure 3 is therefore equal to zero.

During training, we simply add this to our loss. If you have gone through the previous tutorial on how to code PointNet, you may recall that the feature transformation matrix A is returned by the classification head.
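
To make the behaviour of the regularization term concrete, here is a small NumPy check on 2×2 matrices rather than the 64×64 feature transform. It uses the squared Frobenius norm of I - A Aᵀ, as in the PointNet paper; the function name `orthogonality_penalty` and the two test matrices are assumptions chosen for illustration.

```python
import numpy as np

def orthogonality_penalty(A):
    """Squared Frobenius norm of (I - A A^T) for a square matrix A."""
    identity = np.eye(A.shape[0])
    return np.linalg.norm(identity - A @ A.T, ord="fro") ** 2

theta = 0.3
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])
scaling = 2.0 * np.eye(2)  # stretches lengths, so it is not orthogonal

penalty_rotation = orthogonality_penalty(rotation)  # numerically zero
penalty_scaling = orthogonality_penalty(scaling)    # I - 4I = -3I, so 9 + 9 = 18
```

The rotation incurs essentially no penalty, while the scaling matrix is penalized, which is the pressure that nudges the learned transform toward orthogonality.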

Now let's code the PointNet loss function. We have also included weighted (balanced) cross-entropy loss and focal loss terms, but explaining them is beyond the scope of this tutorial. The code is located here; it was adapted from this GitHub repository.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PointNetLoss(nn.Module):
    def __init__(self, alpha=None, gamma=0, reg_weight=0, size_average=True):
        super(PointNetLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reg_weight = reg_weight
        self.size_average = size_average

        # sanitize inputs
        if isinstance(alpha,(float, int)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,(list, np.ndarray)): self.alpha = torch.Tensor(alpha)

        # get Balanced Cross Entropy Loss
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=self.alpha)

    def forward(self, predictions, targets, A):

        # get batch size
        bs = predictions.size(0)

        # get Balanced Cross Entropy Loss
        ce_loss = self.cross_entropy_loss(predictions, targets)

        # reformat predictions and targets (segmentation only)
        if len(predictions.shape) > 2:
            predictions = predictions.transpose(1, 2) # (b, c, n) -> (b, n, c)
            predictions = predictions.contiguous() \
                                     .view(-1, predictions.size(2)) # (b, n, c) -> (b*n, c)

        # get predicted class probabilities for the true class
        pn = F.softmax(predictions, dim=1)  # softmax over the class dimension
        pn = pn.gather(1, targets.view(-1, 1)).view(-1)

        # get regularization term
        if self.reg_weight > 0:
            # build a batch of identity matrices on the same device as A
            I = torch.eye(A.shape[-1], device=A.device).unsqueeze(0).repeat(A.shape[0], 1, 1)
            reg = torch.linalg.norm(I - torch.bmm(A, A.transpose(2, 1)))
            reg = self.reg_weight*reg/bs
        else:
            reg = 0

        # compute loss (negative sign is included in ce_loss)
        loss = ((1 - pn)**self.gamma * ce_loss)
        if self.size_average: return loss.mean() + reg
        else: return loss.sum() + reg
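
To see what the focal factor (1 - p)^gamma does, the following sketch evaluates the standard per-sample focal loss for one easy and one hard example. Note that this is the textbook per-sample form; the class above instead multiplies the factor by the batch-mean cross-entropy, so the numbers are only meant to show the trend. The function names and the probabilities 0.9 and 0.1 are assumptions for the example.

```python
import math

def focal_factor(p_true, gamma):
    """Modulating factor applied to the cross-entropy of one sample."""
    return (1.0 - p_true) ** gamma

def focal_loss(p_true, gamma):
    """Per-sample focal loss, where p_true is the probability of the true class."""
    return -focal_factor(p_true, gamma) * math.log(p_true)

easy, hard = 0.9, 0.1  # predicted probability of the true class

# how much more the hard example contributes than the easy one
ratio_ce = focal_loss(hard, gamma=0) / focal_loss(easy, gamma=0)
ratio_focal = focal_loss(hard, gamma=2) / focal_loss(easy, gamma=2)
```

With gamma = 0 the loss reduces to plain cross-entropy; raising gamma shrinks the contribution of confident, correct predictions much faster than that of hard examples.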

4. Training PointNet for classification

Now that we understand the data and the loss function, we can move on to training.

For our training we need to quantify the performance of the model. Normally we think about loss and accuracy, but for this classification problem, we need a metric that measures misclassification versus correct classification. Think of a typical confusion matrix: true positives, false negatives, true negatives, and false positives; we want a classifier that performs well on all of them.

The Matthews Correlation Coefficient (MCC) quantifies the performance of our model on all these metrics and is considered to be a more reliable single performance metric than accuracy or F1-score. MCC ranges from -1 to 1, where -1 is the worst performance, 1 is the best performance, and 0 is a random guess. We can use MCC with PyTorch through torchmetrics.

from torchmetrics.classification import MulticlassMatthewsCorrCoef

mcc_metric = MulticlassMatthewsCorrCoef(num_classes=NUM_CLASSES).to(DEVICE)
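
For intuition about what the metric measures, the multiclass MCC can also be computed directly from a confusion matrix. The sketch below is a plain-Python illustration and not the torchmetrics implementation; the function name `multiclass_mcc`, the toy matrices, and the convention of returning 0.0 when the denominator is zero are assumptions.

```python
import math

def multiclass_mcc(confusion):
    """MCC from a square confusion matrix (rows = true class, cols = predicted)."""
    k = len(confusion)
    total = sum(sum(row) for row in confusion)
    correct = sum(confusion[i][i] for i in range(k))
    true_counts = [sum(confusion[i]) for i in range(k)]
    pred_counts = [sum(confusion[i][j] for i in range(k)) for j in range(k)]

    numerator = correct * total - sum(t * p for t, p in zip(true_counts, pred_counts))
    denominator = math.sqrt(
        (total ** 2 - sum(p * p for p in pred_counts))
        * (total ** 2 - sum(t * t for t in true_counts))
    )
    # assumed convention: an undefined MCC (zero denominator) is reported as 0.0
    return numerator / denominator if denominator else 0.0

# 90 samples of class 0 and 10 of class 1; the model always predicts class 0
always_majority = [[90, 0], [10, 0]]
accuracy_majority = (90 + 0) / 100
mcc_majority = multiclass_mcc(always_majority)

# a perfect classifier on a balanced toy set
mcc_perfect = multiclass_mcc([[5, 0], [0, 5]])
```

The always-majority classifier reaches 90% accuracy yet has an MCC of 0, which is why MCC is informative on imbalanced data.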

The training process is a basic PyTorch training loop, alternating between training and validation.

We use the Adam optimizer with our PointNet loss function, including the regularization term described in Figure 3 above. For the PointNet loss function, we set alpha, which weights the importance of each class.

We also set gamma to modulate the loss function so that it focuses on hard examples, where hard examples are those classified with lower probability. See the notes in the notebook for more details. We noticed that the model trained better with a cyclic learning rate, so we use one here.

import numpy as np
import torch
import torch.optim as optim
from point_net_loss import PointNetLoss

EPOCHS = 50
LR = 0.0001
REG_WEIGHT = 0.001 

# manually downweight the high frequency classes
alpha = np.ones(NUM_CLASSES)
alpha[0] = 0.5  # airplane
alpha[4] = 0.5  # chair
alpha[-1] = 0.5 # table

gamma = 1

optimizer = optim.Adam(classifier.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0001, max_lr=0.01, 
                                              step_size_up=2000, cycle_momentum=False)
criterion = PointNetLoss(alpha=alpha, gamma=gamma, reg_weight=REG_WEIGHT).to(DEVICE)

classifier = classifier.to(DEVICE)
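
The cyclic schedule configured above follows a triangular policy by default: the learning rate climbs linearly from base_lr to max_lr over step_size_up steps and then falls back over the same number of steps. The function below is a plain-Python sketch of that policy for intuition, not PyTorch's implementation; the name `triangular_lr` is hypothetical, and the defaults simply mirror the values used above.

```python
import math

def triangular_lr(step, base_lr=0.0001, max_lr=0.01, step_size=2000):
    """Learning rate of a triangular cyclic schedule at a given optimizer step."""
    cycle = math.floor(1 + step / (2 * step_size))
    x = abs(step / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * max(0.0, 1.0 - x)

lr_start = triangular_lr(0)       # base_lr
lr_halfway = triangular_lr(1000)  # midway up the ramp
lr_peak = triangular_lr(2000)     # max_lr
lr_end = triangular_lr(4000)      # back at base_lr after one full cycle
```

Since a full cycle spans 4000 steps here, the scheduler is meant to be stepped once per batch rather than once per epoch.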

Please follow the notebook for the training loop and make sure you have a GPU. If you don't, remove the scheduler and set the learning rate to 0.01; after a few epochs you should get good enough results. If you encounter any PyTorch user warnings (due to future updates of nn.MaxPool1d), they can be suppressed with:

import warnings
warnings.filterwarnings("ignore", category=UserWarning)  # silence only user warnings

5. Training results

We can see that both training and validation accuracy go up, but MCC rises only during training, not during validation. This may be due to the very small sample sizes of some classes in the validation and test splits; MCC may therefore not be the best single metric for validation and testing in this case. More investigation is needed to determine when MCC is a good indicator, i.e., how much imbalance is too much for MCC, and how many samples per class are needed for MCC to be effective?

Let's take a look at the test results:

We see that the test accuracy is about 85%, but the MCC is slightly above 0. Since we only have 16 classes, let's look at the confusion matrix in the notebook to gain more insight into the test results.

Figure 6. Test data confusion matrix. Source: Author.

The model classifies most categories well, but there are some less common ones, such as "rocket" or "skateboard", on which it tends to predict poorly. Performance on these less common classes is responsible for the drop in MCC.

Another thing to note is that when you inspect the results (as shown in the notebook), you will see good accuracy and confident predictions for the more frequent classes, but lower confidence and lower accuracy for the less frequent ones.

6. Inspect the critical sets

Now we'll look at the most interesting part of this tutorial: the critical sets. The critical set of a point cloud is the subset of points that define its basic structure. Here's some code showing how to visualize them.

from open3d.web_visualizer import draw 


critical_points = points[crit_idxs.squeeze(), :]
critical_point_colors = read_pointnet_colors(seg.numpy())[crit_idxs.cpu().squeeze(), :]

pcd = o3.geometry.PointCloud()
pcd.points = o3.utility.Vector3dVector(critical_points)
pcd.colors = o3.utility.Vector3dVector(critical_point_colors)

# o3.visualization.draw_plotly([pcd])
draw(pcd, point_size=5) # does not work in Colab

Here are some visualizations. Note that I use "draw()" to get a larger point size, but it doesn't work in Colab.

Figure 7. Point clouds and their corresponding critical sets learned by PointNet

We can see that the critical sets exhibit the overall structure of their corresponding point clouds; they are essentially sparsely sampled versions of those point clouds. This shows that the trained model has learned to pick out the distinguishing structure of each point cloud and suggests that it is able to classify each class based on that structure.
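
In PointNet the global feature is obtained by max pooling the per-point features over all points, so the critical points are those that supply the maximum for at least one feature channel. The sketch below illustrates this with a tiny NumPy array; it is not the tutorial's model code, and the function name `critical_indices` and the 4×3 toy feature matrix are assumptions.

```python
import numpy as np

def critical_indices(point_features):
    """Indices of points that win at least one channel in global max pooling.

    point_features: array of shape (n_points, n_features).
    """
    winners = np.argmax(point_features, axis=0)  # one winning point per channel
    return np.unique(winners)

# toy example: 4 points, 3 feature channels
toy_features = np.array([[0.9, 0.1, 0.2],
                         [0.1, 0.8, 0.7],
                         [0.2, 0.3, 0.1],
                         [0.3, 0.2, 0.4]])
toy_crit_idxs = critical_indices(toy_features)

# the global feature is unchanged if only the critical points are kept
global_full = toy_features.max(axis=0)
global_critical = toy_features[toy_crit_idxs].max(axis=0)
```

Dropping every non-critical point leaves the global feature, and therefore the classification, unchanged, which is why the critical set summarizes what the network actually uses.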

7. Conclusion

We learned how to train PointNet from scratch and how to visualize critical sets. If you're interested, try improving the overall classification performance. Here are some suggestions to get you started:

  • Use a different loss function
  • Try different settings in the cyclic learning rate scheduler
  • Attempt to modify the PointNet architecture
  • Try different data augmentations
  • Use more data → try the full ShapeNet dataset

Original link: PointNet classification 3D point cloud - BimAnt
