PointNet++:Deep Hierarchical Feature Learning on Point Sets in a Metric Space

在上一篇文章中,提及了3D点云分类与分割的开山鼻祖——PointNet:https://blog.csdn.net/Alkaid2000/article/details/127253473,但是这篇PointNet是存在有很多不足之处的,在文章的末尾也提及了,它没有能力捕获局部结构,这使得在复杂的场景中也很难进行分析,道理也很简单,这篇文章只使用了Max操作以及MLP操作,也不符合当前神经网络的主流。PointNet++的作者主要通过两个主要的方法进行了改进,使得网络能更好的提取局部特征:


论文地址: https://arxiv.org/pdf/1706.02413.pdf

GitHub:https://github.com/yanx27/Pointnet_Pointnet2_pytorch

在这里插入图片描述

0x01 PointNet的存在问题

首先我们可以明显的感觉到这个网络与当下主流网络不符,没有局部特征融合,要不自己,要么一个整体。这个PointNet网络也没有关系概念,局部样本点之间肯定存在关系的,它没有考虑到这一层。于是PointNet++就从局部入手,多利用了局部特征。整体思想是不变的,只不过是在特征提取处使用类似图卷积的方式来整合特征。

  • PointNet的MLP,仅仅是对每个点的表征,对局部结构信息整合能力太弱(PointNet++改进:sampling和grouping整合局部邻域

  • global feature直接由max pooling获得,无论对分类还是对分割任务,都会造成巨大的信息损失(PointNet++改进:hierarchical feature learning framework,通过多个set abstraction逐级降采样,获得不同规模不同层次的local-global feature

  • 分割任务的全局特征global feature是直接复制与local feature拼接,生成discriminative feature能力有限(PointNet++的改进:分割任务设计了encoder-decoder结构,先降采样再上采样,使用skip connection将对应层的local-global feature拼接

PointNet++需要解决两个关键的问题:第一,如何将点集划分为不同的区域;第二,如何利用特征提取器获取不同区域的局部特征

提到这个局部特征提取,可能大家都会想到CNN,因为在二维的图像中,卷积块成为了基本的特征提取器;但在3D点云中,我们同样需要找到结构相同的子区域和对应的区域特征提取器。在本文中,作者使用了PointNet作为特征提取器,另外一个问题就是如何来划分点集从而产生结构相同的区域,作者使用邻域球来定义分区,每个区域可以通过中心坐标和半径来确定。中心坐标的选取,作者使用了最远采样点算法来实现。

0x02 PointNet++算法解读

PointNet++是PointNet的延伸,在PointNet的基础上加入了多层次结构(hierarchical structure),使得网络能够在越来越大的区域上提供更高级别的特征。网络中的每一组set abstraction layers主要包括三个部分:

  • Sample layer:主要是对输入点进行采样,在这些点中选出若干个中心点。
  • Grouping layer:是利用上一步得到的中心点将点集划分成若干个区域。
  • PointNet layer:是对上述得到的每个区域进行编码,变成特征向量。

(一)最远采样方法(farthest point sampling)

基于半径选择局部区域(类似于得到很多个簇),之后针对得到的每个区域进行特征提取(卷积),那么要解决的问题就是:如何选择区域(簇中心点选择),簇的半径大小如何定义,每个簇中选择多少个样本点。

在这里插入图片描述

  • 先确定好每一个局部区域,接下来对局部区域执行pointnet:
    在这里插入图片描述

例如我们现在需要输入1024个点,要选择128个中心点(簇),要如何采样呢:

在这里插入图片描述

选簇的这个感觉有点像是下采样的感觉,要尽可能地覆盖全部的数据,所以提出了最远点采样。比如最开始我们将点设置在了点云的头部,那么下一个点必须离我当前这个点最远,这样尽可能能覆盖;那么第三个点的确定,首先我们先画下两个点,需要确定这两个点距离第一个点的距离d1和第二个点距离d2,根据最远采样原则,我们要保留距离最大的那个点,成为我们第三个点。原则:离我其他已采样点是最远的。这样才可以尽可能覆盖到我们整个点云。

(二)分组(grouping)

假设我们现在得到了128个中心点,那么我们原始的输入其实是batch * 1024 * 6(1024个点,每个点对应3个坐标3个法向量信息),那么我经过最远采样后,我们将中心点以一定的半径将其包裹起来,所有的中心点圈到的点的数量必须一致,如果这个范围内不满足我们要求的点,那我们就会复制离中心点最近的点,直到补齐;如果比我们设定的点多的话,那么首先我们需要进行排序,以距离中心点的距离排序,从小到大排序,排序完后我们剔除离我们远的那些点,直到满足条件。假设我们一个圈中,需要有16个样本,分组后输出为:batch* 128 * 16 * 6(128个中心点,每个簇16个样本)。实际计算时是选择多种半径,多种样本点个数,目的是为了特征更丰富:例如半径=(0.1,0.2,0.4),那么对应簇的样本个数(16,32,64)。这样就可以通过多种半径的圈,可以得到多种特征,兼顾了局部与全局。

(三)对各组进行特征提取

  • 先进行维度变换(b* npoints * nsample * features,8128 * 16 * 6-> 8 6 * 16* 128),其中的npoints可以理解为channel。
  • 进行卷积操作(例如:in:6,out:64)就得到提取的特征(8 * 64 * 16 *128)。
  • 注意当前每个簇都是16个样本点,我们要每一个簇对应一个特征。
  • 按照pointnet,做max操作,得到8 * 64 * 128(得到64维特征)
  • 继续做多次采样,分组,卷积。
    • 例如:采样中心点(1024->512->128)
    • 每一次操作时,都要进行特征拼接(无论半径为0.1,0.2,0.4;以及簇采样点个数)
    • 最终都得到batch* 中心点个数 * 特征(但是特征个数可能不同)
    • 执行拼接操作(b * 512 * 128,b* 512 * 256,b* 512 * 512)->(b* 512 * 869)

那么在这里就结束了我们的特征提取了,接下来就是要对我们的特征进行分类。

(四)分类整体网络架构

在这里插入图片描述

那么先开始输入:其中的N可以理解为我们现在得到了多少个样本点,d可以代表我们现在当前的位置,C可以代表当前的特征。之后我们就做了一些簇,进行了pointnet,进行了最远采样操作,位置信息保持不变,C的特征数就进行改变了,通过不断的sampling与grouping,最终可以不断地更新C的值,接下来就是进行分类任务了。

首先先进行一次pointnet操作,其实也就是max操作,就可以得到一个实际的向量,再连全连接层,做一个分类任务。

(五)分割整体网络架构

在这里插入图片描述

分割任务有些不同,要得到每个点的特征,还需要进行上采样的操作。因为我们取了特征,点变少了,我们需要分割的话,要对大体进行分割,需要多一些的点,所以需要上采样,这个感觉会不会比较像是我们的点云补全?补全的时候是通过权重参数与距离相结合,得到新的特征后在与原先的点进行拼接;再继续做上采样,再次拼接,直到补全了我们输入的大小。

0x03 PointNet++遇到的问题

上面是整个PointNet++网络的基本流程,其实它很容易遇到这么一个问题,那就是样本点的个数,很容易受样本点的个数影响

在这里插入图片描述

随着点的减小,准确率会越来越低。那么如何改进呢:

(一)点云分布不一致的处理方法

点云分布不一致时,每个区域中如果生成的时候使用相同的半径r,会导致有些区域采样点过少。作者提到这个问题需要解决,并且提出了两个方法:Multi-scale grouping (MSG) and Multi-resolution grouping (MRG)

在这里插入图片描述

多半径进行特征拼接 或者 跨层来提取不同分辨率特征。

MSG:使用了不同的半径进行读取特征,最后将其拼接在一起。(源码中做的)

MRG:使用不同的分辨率,也就是分层来进行提取特征,最后将其拼接在一起。

改进之后的效果如下:

在这里插入图片描述

可以发现采样点个数下降时候不会严重影响结果,稳定性也得以提升。

论文的最后也提出了,PointNet++相比于PointNet还是取得了很好的效果,主要还是应用于分类与分割中:

在这里插入图片描述

在这里插入图片描述

0x04 源码阅读

在这里插入图片描述

在上面可以看到pointnet的一些文件,这是作为特征提取器的,也是原来的pointnet的代码。可以注意到文件名有msg以及ssg为文件名,区别在于msg使用的是多个半径来取特征点,而ssg只使用了一个半径。

接下来是设置训练的参数:

在这里插入图片描述

在开始训练的时候只需要:

python test_cls.py --model pointnet2_cls_msg --normal  --log_dir pointnet2_cls_msg

(一)环境配置

conda create -n PointNet++ python==3.7
conda activate PointNet++
conda install pytorch==1.6.0 cudatoolkit=10.1 -c pytorch
pip install tqdm

(二)数据读取模块配置

其读取数据模块位于ModelNetDataLoader.py文件中的ModelNetDataLoader,首先要进行相关的初始化:

    def __init__(self, root,  npoint=1024, split='train', uniform=False, normal_channel=True, cache_size=15000):
        self.root = root
        self.npoints = npoint
        self.uniform = uniform
        self.catfile = os.path.join(self.root, 'modelnet40_shape_names.txt')

        self.cat = [line.rstrip() for line in open(self.catfile)]
        self.classes = dict(zip(self.cat, range(len(self.cat))))
        self.normal_channel = normal_channel

        shape_ids = {
    
    }
        shape_ids['train'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_train.txt'))]
        shape_ids['test'] = [line.rstrip() for line in open(os.path.join(self.root, 'modelnet40_test.txt'))]
		
        assert (split == 'train' or split == 'test')
        shape_names = ['_'.join(x.split('_')[0:-1]) for x in shape_ids[split]]
        # list of (shape_name, shape_txt_file_path) tuple
        self.datapath = [(shape_names[i], os.path.join(self.root, shape_names[i], shape_ids[split][i]) + '.txt') for i
                         in range(len(shape_ids[split]))]
        print('The size of %s data is %d'%(split,len(self.datapath)))

        self.cache_size = cache_size  # how many data points to cache in memory
        self.cache = {
    
    }  # from index to (point_set, cls) tuple

可以看看对应的数据集:

在这里插入图片描述

之后将其类别以及id存进来,再读对应的坐标点:

    def _get_item(self, index):
        if index in self.cache:
            point_set, cls = self.cache[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            cls = np.array([cls]).astype(np.int32)
            point_set = np.loadtxt(fn[1], delimiter=',').astype(np.float32)
            #读入xyz特征以及额外信息特征
            if self.uniform:
                point_set = farthest_point_sample(point_set, self.npoints)
            #采样,只取其中一部分
            else:
                point_set = point_set[0:self.npoints,:]

            point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])#相当于3列值,做坐标值标准化

            if not self.normal_channel:
                point_set = point_set[:, 0:3]

            if len(self.cache) < self.cache_size:
                #标准化以及标签
                self.cache[index] = (point_set, cls)

        return point_set, cls

(二)网络模型架构

可以在文件pointnet2_cls_msg.py中可以看到forward函数:

    def forward(self, xyz):
        B, _, _ = xyz.shape
        if self.normal_channel:
            norm = xyz[:, 3:, :]
            xyz = xyz[:, :3, :]	#位置信息
        else:
            norm = None
        print(xyz.shape)
        print(norm.shape)
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)  
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        x = l3_points.view(B, 1024)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)


        return x,l3_points

在进入forward之前,我们首先要看看这些东西它经历了哪些层:

    def __init__(self,num_class,normal_channel=True):
        super(get_model, self).__init__()
        in_channel = 3 if normal_channel else 0
        self.normal_channel = normal_channel
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [16, 32, 128], in_channel,[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4, 0.8], [32, 64, 128], 320,[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
        self.sa3 = PointNetSetAbstraction(None, None, None, 640 + 3, [256, 512, 1024], True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(256, num_class)

那么其中的PointNetSetAbstractionMsg,其实就是规定了多少个中心点以及圆的半径,下面就是各种采样的操作了。

那么在这句开始,我们就已经进入网络了:

l1_xyz, l1_points = self.sa1(xyz, norm)

他会进入pointnet_util.py这个文件中的PointNetSetAbstractionMsg类中,也就是我们网络模型开始取特征的开始啦:

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1) #就是坐标点位置特征
        print(xyz.shape)
        if points is not None:
            points = points.permute(0, 2, 1) #就是额外提取的特征,第一次的时候就是那个法向量特征
        print(points.shape)
        B, N, C = xyz.shape
        S = self.npoint #表示选多少个中心点
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))#采样后的点,最远点采样函数
        print(new_xyz.shape)
        new_points_list = []
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)#返回的是索引
            grouped_xyz = index_points(xyz, group_idx)#得到各个组中实际点
            grouped_xyz -= new_xyz.view(B, S, 1, C)#去mean new_xyz相当于簇的中心点
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
                print(grouped_points.shape)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            print(grouped_points.shape)
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points =  F.relu(bn(conv(grouped_points)))
            print(grouped_points.shape)
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S] 就是pointnet里的maxpool操作
            print(new_points.shape)
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)
        print(new_points_concat.shape)
        return new_xyz, new_points_concat

上面函数输入了位置信息xyz以及特征信息points,原始输入的话那么这个points就是我们的法向量啦,接下来这个返回值,就是我们计算后得到的这个点,下一步要进行下采样或者是各种操作,提特征也是通过多种半径进行提取的,最后将其拼接在一起。

那么最远点采样的实现细节:

def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)# 8*512
    #距离矩阵
    distance = torch.ones(B, N).to(device) * 1e10 # 8*1024
    #第一个点是随机来的,后面的点才有距离计算
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)# batch里每个样本随机初始化一个最远点的索引
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):
        centroids[:, i] = farthest # 第一个采样点选随机初始化的索引
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)# 得到当前采样点的坐标 B*3
        dist = torch.sum((xyz - centroid) ** 2, -1)# 计算当前采样点与其他点的距离
        mask = dist < distance# 选择距离最近的来更新距离(更新维护这个表)
        distance[mask] = dist[mask]# 把最近的那个距离存下来
        farthest = torch.max(distance, -1)[1]# 重新计算得到最远点索引(在更新的表中选择距离最大的那个点),在最小距离中选择距离最大的点
    return centroids

他的返回值是告诉采样后的点哪个点是中心点,最后就可以得到所有的中心点啦,这个函数也是sampling的环节。那么这个函数结束后,他会把结果给到这个函数index_points:

def index_points(points, idx):
    """

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points

这个函数是将索引与得到的位置索引,传到point中,得到实际的特征,得到了每个中心点的特征。接下来就是MSG了:

        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)#返回的是索引,以中心点为圆心,画圈
            grouped_xyz = index_points(xyz, group_idx) # 得到各个组中实际点
            grouped_xyz -= new_xyz.view(B, S, 1, C) # 去mean new_xyz相当于簇的中心点,每个值都减去中心点,给人一种去均值的感觉
            if points is not None:
                # 基于关系提特征,要将位置信息以及法向量进行拼接
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
                print(grouped_points.shape)
            else:
                grouped_points = grouped_xyz

半径在上面的__init _ _中已经初始化成功了。那么我们重点看看这个函数query_ball_point:

def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    #每个组大小规模都一样
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)# 得到B N M (就是N个点中每一个和M中每一个的欧氏距离)
    group_idx[sqrdists > radius ** 2] = N # 找到距离大于给定半径的设置成一个N值(1024)索引,表示不在我们半径当中
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]#做升序排序,后面的都是大的值(1024),排序后才可以进行筛选
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])#如果半径内的点没那么多,就直接用第一个点来代替了,直接就是赋值到我们规定的那个数
    mask = group_idx == N # 剔除不在半径中的点
    group_idx[mask] = group_first[mask]
    return group_idx

最后返回的是每个组的所有用到的点,也就是我们圈中的点。得到我们的中心点后,接下来就要提取特征啦:

        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)#返回的是索引,以中心点为圆心,画圈
            grouped_xyz = index_points(xyz, group_idx) # 得到各个组中实际点
            grouped_xyz -= new_xyz.view(B, S, 1, C) # 去mean new_xyz相当于簇的中心点,每个值都减去中心点,给人一种去均值的感觉
            if points is not None:
                # 基于关系提特征,要将位置信息以及法向量进行拼接
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
                print(grouped_points.shape)
            else:
                grouped_points = grouped_xyz
			# 维度转换
            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            print(grouped_points.shape)
            # 输入几个特征
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                # 提特征
                grouped_points =  F.relu(bn(conv(grouped_points)))
            print(grouped_points.shape)
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S] 就是pointnet里的maxpool操作,取最大值
            print(new_points.shape)
            new_points_list.append(new_points)

最后把我们所有半径找到的特征都拼接在一起:

    new_xyz = new_xyz.permute(0, 2, 1)
    new_points_concat = torch.cat(new_points_list, dim=1)
    print(new_points_concat.shape)
    return new_xyz, new_points_concat

那么经过上面层层的考核,我们终于把前向传播的sa1走完了,接下来就是sa2:

在这里插入图片描述

那么sa1与sa2有什么区别:

在这里插入图片描述

其实就是改变了参数,可以发现这一次的目的是要在512中选128个特征了。之后又要经历我们上面所说的东西了,一模一样的过程,只不过第二次输入进来的特征变成上一次得到的特征数了。

接下来就是sa3了,可以看到这个参数与sa1、sa2不一样,sa3并不是多半径:

他把所有的点都归为一个组,之后进行MLP,进行一次maxpooling,最后得到了我们最终的特征,每个点取最大值,这一点跟pointnet中的操作是一模一样的。

之后进行全连接层,1024->512,512->256,256->40,最终对1024个特征完成40类别的分类,最后softmax一下就可以得到输出结果。

以上就结束了分类任务。那么接下来就轮到分割任务啦:

分割任务需要这么配置:

python train_partseg.py --model pointnet2_part_seg_msg  --normal --log_dir pointnet2_part_seg_msg

在这里插入图片描述

可以发现其参数都是差不多的,大同小异,然后其不一样的地方在于,在参数的规定中,我们都规定了这输入的点的个数为2048个点:

在这里插入图片描述

那么对于要达到这个2048个点的指标,我们也有进行上采样的操作:

在这里插入图片描述

其具体操作其实跟我们的分类任务很像:
在这里插入图片描述

 def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)
        print(xyz1.shape)
        print(xyz2.shape)

        points2 = points2.permute(0, 2, 1)
        print(points2.shape)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # 复制操作,为了达到指标(第一次)
            interpolated_points = points2.repeat(1, N, 1)
            print(interpolated_points.shape)
        else:
            #根据距离插值,把最近的点进行插值扩充
            dists = square_distance(xyz1, xyz2)
            print(dists.shape)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]

            dist_recip = 1.0 / (dists + 1e-8)
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm
            print(weight.shape)
            print(index_points(points2, idx).shape)
            
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
            print(interpolated_points.shape)
		
        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            # 拼接操作
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points
        print(new_points.shape)
        new_points = new_points.permute(0, 2, 1)
        print(new_points.shape)
        #MLP 找到特征
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        print(new_points.shape)
        return new_points

之后得到的这个特征又要进行上采样操作,使用的近距离插值,加权得到我们下一个点,最终得到了2048个点,都走上面那个代码。

之后对我们的点进行做分类任务,也就是相当于分割了:

		feat = F.relu(self.bn1(self.conv1(l0_points)))
        x = self.drop1(feat)
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)

这个就是分割任务的细节实现了,只不过相比于分类任务多了一个上采样。

猜你喜欢

转载自blog.csdn.net/Alkaid2000/article/details/127279733