PointNet++ の分類とセグメンテーションの詳細な説明

序文

       PointNet++ は、不規則な形状の点群データの分類およびセグメンテーション タスクのためのディープ ニューラル ネットワークです。従来のグリッドベースの 3D データ表現方法と比較して、点群データの取得と処理が容易です。PointNet++ のもう 1 つの利点は、より複雑な点群データを処理できるマルチスケール階層を導入していることです。PointNet の最初のバージョンと比較して、著者は多くの新しいアイデアを提案し、非常に良い結果を達成しました。

PointNet アルゴリズムの問​​題

(1) 点群画像内の点が多すぎると、過剰な計算が発生し、アルゴリズムの速度が低下します。

(2) 点群を異なる領域に分割し、異なる領域の局所的特徴を取得するにはどうすればよいですか?

(3) 点群が不均一な場合、この問題を解決するにはどうすればよいでしょうか?

これらの疑問を念頭に置いて、私たちは次に論文とソースコードを通じてこれらの問題を解決し始めました。

分類タスク

8a4d9f3376a45c32f22c60ad476248c8.jpeg

階層抽出機能セットの抽象化

このモジュールは主に 3 つの部分で構成されます。

1. サンプリング層 (サンプル層): この記事の最初の問題を解決するために、密な点群からいくつかの比較的重要な点を中心点として抽出します。つまり、FPS (最遠点サンプリング) 最遠点サンプリング法です。 。

2. グループ レイヤー (グループ レイヤー): 中心点の近くに最も近い K 点を見つけて、ローカル ポイント領域を形成します。この操作は画像の畳み込みに似ており、簡単に特徴を抽出できるように畳み込み画像を形成します。2 番目の問題を解決します。

3. 特徴抽出層 (ポイントネット層): 特徴抽出層。各局所点領域の特徴を抽出します。

FPS のプロセスは次のとおりです。

(1) 最初のサンプリング点として点をランダムに選択します。

(2) 選択されていないサンプリング点セットと選択されたサンプリング点セットの各点間の距離を計算し、最も距離が大きい点を選択されたサンプリング点セットに追加します。

(3) 新しいサンプリング ポイントに従って距離を計算し、目標のサンプリング ポイント数が得られるまで継続的に繰り返します。

下図のように、全点の中から5点を選択します。

da1a282f9e73683da21c50df354f1151.png

コードは次のように実装されます

def farthest_point_sample(xyz, npoint):     
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)         #[b,npoint]    #npoint是要从很多的点中筛选出那么多
    distance = torch.ones(B, N).to(device) * 1e10                           #[b,N]         #N指原来有N个点
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)       #[b]           在0-N中随机生成了B个点(随机选了个点作为初始点,并且用的索引
    batch_indices = torch.arange(B, dtype=torch.long).to(device)            #[b]            0-b
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)           #获得索引对应的位置的坐标
        dist = torch.sum((xyz - centroid) ** 2, -1)                        #计算所有坐标和目前这个点的距离
        mask = dist < distance                                             #距离符合要求的
        distance[mask] = dist[mask]                                        #将符合要求的距离都放入
        farthest = torch.max(distance, -1)[1]                              #最远距离对应的索引  [b]
    return centroids                                                       #最终输出筛选的[n,npoint]个点的位置

パケット層のフローは次のとおりです。

(1) FPSによるスクリーニング後、対応する中心点を取得します。

(2) 原点セットの各中心点を使用して、距離に応じて近くに必要なポイントの数をフィルタリングし、各 FPS ポイントを中心とする新しいポイント セットを形成します。

(3) 新しい点セットは、座標正規化と同様の操作を実行して 3 つの新しい特徴を形成し、特徴抽出前に各点の元の特徴と組み合わせて新しい特徴を形成します。

簡略化した図は次のとおりです。

c44b141189eccb1afe7c9682cd5ef7a2.png

   図1  

3346b3e3963e5708250d5134402e3d21.png

図Ⅱ

       図 1 では、赤い点は FPS 結果の中心点であり、黒い点はいくつかの初期点です。図 2 の緑色の点は距離に応じてフィルタリングされた点であり、これらの点と赤色の点が一連の点セットを形成します。

       グループ化層では、著者は SSG (単一スケール グループ化)、MSG (マルチスケール グループ化) マルチスケール、および MRG (マルチ解像度グループ化) マルチ解像度の 3 つの方式を提案します。実際、複数のサンプリング グループが異なる半径または異なる解像度で実行されます。それはこの記事の 3 番目の問題を解決するためでもあります。

SSG: 半径が 1 つだけのグループ サンプリングと同等

9f36152f36d59afa2788301404fd4c5a.png

MSG: 同じ解像度で複数の半径グループをサンプリングし、点セットを結合することと同じです。

aa5ae0262f303f506022443365ce5abf.pngccc4a6e3d0932305999fc9383e037850.pngaa4484154577d1dd1e946f27ac613522.png

MRI

0e7daba91bfd9d4b254cc5b2d26d4526.png

コードは次のように実装されます

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region   
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    """
    1.预设搜索区域的半径R与子区域的点数K
    2.上面提取出来了 s 个点,作为 s个centriods。以这 s个点为球心,画半径为R的球体(叫做query ball,也就是搜索区域)。
    3.在每个以centriods的球心的球体内搜索离centriods最近的的点(按照距离从小到大排序,找到K个点)。
      如果query ball的点数量大于规模nsample,那么直接取前nsample个作为子区域;如果小于,那么直接对某个点重采样(此处直接复制了第一个点),凑够规模nsample
    """


    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])    #[2,512,1024]
    sqrdists = square_distance(new_xyz, xyz)    #获得采样后点集与原先点集的距离[2,512,1024]
    group_idx[sqrdists > radius ** 2] = N       #将距离比半径大的,将此处置为1024
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  #因为是0-1023的索引,不符合要求的变味了1024,再对索引排序,获得前nsample个  [b,s,nsample]
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])  #拿了符合要求的第一个索引点索引,也就是中心点,并且复制了nsample次     [b,s,nsample]
    mask = group_idx == N                #查看那前nsample中有没有不符合要求的
    group_idx[mask] = group_first[mask]   #其中将不符合要求的点,全部换成符合要求的第一个点
    return group_idx   


  def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint:
        radius:
        nsample:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, nsample, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint]        #FPS最远点采样算法  获得需要点的索引
    new_xyz = index_points(xyz, fps_idx)          # [B, npoint,c]                        #将对应索引的点拿出来
    idx = query_ball_point(radius, nsample, xyz, new_xyz)       #[b,npoint,nsample]         #进行query_ball处理,类似于卷积操作,找到new_xyz附近原xyz的nsample个点,返回对应的索引[b,npoint,nsample]
    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]       #将对应索引的点拿出来
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)   #每个点减去自己半径内中心的那个点进行归一化[B, npoint, nsample, C]


    if points is not None:   #points即原来点就存才的一些特征
        grouped_points = index_points(points, idx)       #将每个区域原先的特征拿出来   [b,npoint,nsample,D]
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]   #将归一化数据和原先的特征结合
    else:
        new_points = grouped_xyz_norm                    #如果原先没有特征的,那么数据就是归一化后的点
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points

特徴抽出層

       この層は、いくつかの基本的な畳み込みプーリング操作です。最後に、FPS によって選択されるポイント フィーチャが増加します。

e1e2cabd55575cc3a06a62d25bf6685a.png45843d2325cec2c04e0775a6c64bbbfd.png

       設定した抽象化レイヤーを数回繰り返し、最後に完全に接続されたネットワークをいくつか接続して点群を分類します。

分割タスク

4b41ac3b32c82a46e2671fbf669cebaf.png

       セグメンテーションタスクの特徴抽出器は分類タスクの特徴抽出器と同じですが、次に主にアップサンプリングリンクについて説明します。著者は、距離補間に基づく階層的特徴伝播 (Feature Propagation) 戦略を提案します。

一般的なプロセスは次のとおりです。

(1) 以下の図に示すように、逆距離加重平均の重みを計算します。赤い点は FPS 特徴抽出後の点です。実際、赤い点の特徴の数は黒い点の特徴の数よりも多くなります。ポイント。そして、アップサンプリングは、これらの黒い点も一致する特徴を生成するようにすることです。

2251d7ded991ae40f9bae7824f5ac153.png

重みの計算: 各黒い点について、最も近い 3 つの赤い点を見つけて、各赤い点の距離の重みに従って特徴を黒い点に割り当てます。次に、各黒い点が新しい特徴を生成します。赤い点も同じことを行う必要があります。つまり、上図のすべての点で新しいフィーチャが生成されます。

7af870025b17fa75334da7b853eed9ae.pngbc119e8b1403fadf124ec834ca6b05b2.png718fc4bece948c7ea45bb123e7bd9bca.pngd361bd6ea005831c92ee017c4135b41c.png

(2) 生成された新しい特徴は、前の層の特徴を使用して cat 操作され、畳み込みによって特徴の融合が完了します。ワンステップアップサンプリングが終了します。

コードは次のように実装されます

class PointNetFeaturePropagation(nn.Module):  #上采样
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel


    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]      上一层的
            xyz2: sampled input points position data, [B, C, S]  现在层的
            points1: input points data, [B, D1, N]            上一层的
            points2: input points data, [B, D2, S]            现在层的
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)


        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape


        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)    #计算两点集各点之间距离[b,N,S]
            dists, idx = dists.sort(dim=-1)        #排序
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]  获得距离最近的3点


            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm              #根据距离设置权重
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)   #[B,N,D2]


        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)   #上采样cat  [B,N,D1+D2]
        else:
            new_points = interpolated_points


        new_points = new_points.permute(0, 2, 1)                             #[B,D1+D2,N]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points         #[B, D', N]

       以上でPointNet++に関する記事分析は終了です!誤解がある場合は、批判と修正を歓迎します。一緒に進歩しましょう!

おすすめ

転載: blog.csdn.net/weixin_41202834/article/details/130313302