PointNet++ code comments

  I reproduced PointNet++ and tried it on a dataset I made myself; the results were very good, so I started reading the code to build a foundation and, along the way, pick up some inspiration and ideas. Some of the comments in the code are explained in terms of that dataset. It follows the ShapeNet dataset format and contains only one object category, book, which has two parts: background and seam (the book's seam), corresponding to labels 0 and 1 respectively.
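  Each sample is assumed to be a plain-text file in the ShapeNet part-annotation style: one point per row, xyz in the first three columns, optionally the normal in the next three, and the part label in the last column. A minimal sketch of how such a row is read (the column layout is an assumption inferred from the loader below, and the file path is hypothetical):

import numpy as np

# illustrative row: "0.128 0.053 -0.231 0.707 0.000 0.707 1"
data = np.loadtxt('12345678/1.txt').astype(np.float32)  # hypothetical file path
xyz = data[:, 0:3]       # point coordinates
normals = data[:, 3:6]   # per-point normals (if present)
seg = data[:, -1]        # part label: 0 = background, 1 = seam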


ShapeNetDataLoader.py

  ShapeNetDataLoader.py turns the n point-cloud samples into a dataset with n items; each item contains the point data, the object category of the sample (book), and the per-point part label (background or seam). A usage sketch follows the listing below.

# *_*coding:utf-8 *_*
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')

def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc

class PartNormalDataset(Dataset):
    def __init__(self, root='./data/book_seam_dataset', npoints=50000, split='train', class_choice=None, normal_channel=False):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.normal_channel = normal_channel


        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        self.cat = {k: v for k, v in self.cat.items()} # {'book': '12345678'}
        self.classes_original = dict(zip(self.cat, range(len(self.cat)))) # {'book': 0}

        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
        # print(self.cat)

        self.meta = {}
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
            train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) # {'1', '2', ...}
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        for item in self.cat:  # item:'book'
            # print('category', item)
            self.meta[item] = []
            dir_point = os.path.join(self.root, self.cat[item])
            fns = sorted(os.listdir(dir_point))
            # print(fns[0][0:-4])
            if split == 'trainval': # use training set + validation set
                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] # fn[0:-4] strips the '.txt' extension, e.g. '1.txt' -> '1'; fns: ['1.txt', '10.txt', ...]
            elif split == 'train':
                fns = [fn for fn in fns if fn[0:-4] in train_ids]
            elif split == 'val':
                fns = [fn for fn in fns if fn[0:-4] in val_ids]
            elif split == 'test':
                fns = [fn for fn in fns if fn[0:-4] in test_ids]
            else:
                print('Unknown split: %s. Exiting..' % (split))
                exit(-1)

            # print(os.path.basename(fns))
            for fn in fns:
                token = (os.path.splitext(os.path.basename(fn))[0]) # os.path.basename drops the directory part and keeps the file name; token: '1'
                self.meta[item].append(os.path.join(dir_point, token + '.txt')) # {'book': ['data/book_seam_datas...5678/1.txt','...',...}

        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn))

        self.classes = {}
        for i in self.cat.keys():
            self.classes[i] = self.classes_original[i]

        # Mapping from category ('Chair') to a list of int [10,11,12,13] as segmentation labels
        self.seg_classes = {'book': [0, 1]}

        # for cat in sorted(self.seg_classes.keys()):
        #     print(cat, self.seg_classes[cat])

        self.cache = {}  # from index to (point_set, cls, seg) tuple
        self.cache_size = 20000  # cache the loaded samples in memory, up to 20000 of them


    def __getitem__(self, index):
        if index in self.cache:
            point_set, cls, seg = self.cache[index]
        else:
            fn = self.datapath[index] # ('book', 'data/book_seam_datas...5678/5.txt')
            cat = self.datapath[index][0] # 'book'
            cls = self.classes[cat] # 0, the category index of 'book'
            cls = np.array([cls]).astype(np.int32)
            data = np.loadtxt(fn[1]).astype(np.float32)
            if not self.normal_channel:
                point_set = data[:, 0:3]
            else:
                point_set = data[:, 0:6]
            seg = data[:, -1].astype(np.int32)
            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, cls, seg)
        point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])

        choice = np.random.choice(len(seg), self.npoints, replace=True)
        # resample
        point_set = point_set[choice, :]
        seg = seg[choice]

        return point_set, cls, seg # point data, object category (book), per-point part label (background, seam)

    def __len__(self):
        return len(self.datapath)
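  A minimal usage sketch of the loader (batch size, npoints and the root path are illustrative):

from torch.utils.data import DataLoader

dataset = PartNormalDataset(root='./data/book_seam_dataset', npoints=2048, split='train', normal_channel=False)
loader = DataLoader(dataset, batch_size=4, shuffle=True)
points, label, seg = next(iter(loader))
# points: [4, 2048, 3]  normalized xyz coordinates
# label:  [4, 1]        object category index ('book' -> 0)
# seg:    [4, 2048]     per-point part label (0 = background, 1 = seam)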

pointnet2_part_seg_msg.py

  pointnet2_part_seg_msg.py defines the overall framework of the network. It is assembled layer by layer from the custom modules in pointnet2_utils.py, so the low-level details of the network are not visible in this file itself; a simplified sketch of how the pieces fit together is shown below.
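  A minimal sketch of the encoder/decoder composition (layer widths, radii and sample counts here are illustrative, not the exact values of the original file, and the one-hot object-category input that the original feeds into the decoder is omitted):

import torch.nn as nn
import torch.nn.functional as F
from pointnet2_utils import PointNetSetAbstractionMsg, PointNetSetAbstraction, PointNetFeaturePropagation

class PartSegSketch(nn.Module):
    """Illustrative skeleton only; not the exact architecture of pointnet2_part_seg_msg.py."""
    def __init__(self, num_part=2):
        super().__init__()
        # Encoder: multi-scale set abstraction, downsampling N -> 512 -> 128 -> 1 points
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2], [16, 32], 0, [[32, 64], [32, 64]])
        self.sa2 = PointNetSetAbstractionMsg(128, [0.2, 0.4], [32, 64], 128, [[64, 128], [64, 128]])
        self.sa3 = PointNetSetAbstraction(None, None, None, 256 + 3, [256, 512], group_all=True)
        # Decoder: feature propagation, upsampling 1 -> 128 -> 512 -> N points
        self.fp3 = PointNetFeaturePropagation(512 + 256, [256, 256])
        self.fp2 = PointNetFeaturePropagation(256 + 128, [128, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128])
        self.head = nn.Conv1d(128, num_part, 1)   # per-point part scores

    def forward(self, xyz):                                          # xyz: [B, 3, N]
        l1_xyz, l1_points = self.sa1(xyz, None)                      # [B, 3, 512], [B, 128, 512]
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)              # [B, 3, 128], [B, 256, 128]
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)              # [B, 3, 1],   [B, 512, 1]
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)   # [B, 256, 128]
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)   # [B, 128, 512]
        l0_points = self.fp1(xyz, l1_xyz, None, l1_points)           # [B, 128, N]
        return F.log_softmax(self.head(l0_points), dim=1)            # [B, num_part, N]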


pointnet2_utils.py

  The building-block modules and key algorithms of PointNet++ are defined in pointnet2_utils.py.

import torch
import torch.nn as nn
import torch.nn.functional as F
from time import time
import numpy as np

def timeit(tag, t):
    print("{}: {}s".format(tag, time() - t))
    return time()

def pc_normalize(pc):
    l = pc.shape[0]
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc

def square_distance(src, dst):
    
    """
    Calculate Euclid distance between each two points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    
    B: batch size, N: number of points in src, M: number of points in dst, C: number of input channels (C=3 for xyz)
    Input:
        src: source points, [B, N, C] 
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
        i.e. batchsize matrices of shape [N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) # permute: swap dimensions so the matmul gives src . dst^T
    dist += torch.sum(src ** 2, -1).view(B, N, 1) # view: reshape so it broadcasts along M
    dist += torch.sum(dst ** 2, -1).view(B, 1, M) # broadcasting: the term is repeated along N and added to dist
    return dist
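# Illustrative check (not part of the original file): with src at the origin and dst at
# (1, 1, 1), every pairwise squared distance is 1 + 1 + 1 = 3.
#   src = torch.zeros(1, 2, 3); dst = torch.ones(1, 4, 3)
#   square_distance(src, dst)   # shape [1, 2, 4], every entry 3.0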


def index_points(points, idx):  # given the point cloud and an index tensor, return the points at those indices
    """

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]
    Return:
        new_points:, indexed points data, [B, S, C]
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)         #view_shape=[B,S]
    view_shape[1:] = [1] * (len(view_shape) - 1)    #[1] * (len(view_shape) - 1) -> [1], i.e. view_shape=[B,1]
    repeat_shape = list(idx.shape)    #repeat_shape=[B,S]
    repeat_shape[0] = 1    #repeat_shape=[1,S]
    #.view(view_shape)=.view(B,1)
    #.repeat(repeat_shape)=.repeat(1,S)
    #batch_indices has shape [B,S]
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


def farthest_point_sample(xyz, npoint):
    '''
    The logic of FPS (farthest point sampling):

        Suppose there are n points in total, the whole set is N = {f1, f2, ..., fn}, and the goal is to pick n1 of them as the centroids for the next stage:

        1. Randomly pick one point fi as the starting point and put it in the result set B = {fi};
        2. Compute the distance from each of the remaining n-1 points to fi and add the farthest one fj, giving B = {fi, fj};
        3. For each of the remaining n-2 points, compute its distance to every point already in B and take the smallest of those as its distance to the set; among these n-2 set distances, pick the point with the largest one, giving B = {fi, fj, fk}. If n1 = 3 the selection is finished;
        4. If n1 > 3, repeat the previous step until n1 points have been selected.
    '''
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape     # B: batch size, N: number of points in the cloud, C: coordinate dimension
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) # indices of the sampled centroids
    distance = torch.ones(B, N).to(device) * 1e10                   # distance from every point to the selected set, initialized very large
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)  # index of the current farthest point, randomly initialized in [0, N) for every sample in the batch
    batch_indices = torch.arange(B, dtype=torch.long).to(device)    # batch indices, the array 0..B-1
    for i in range(npoint):
        centroids[:, i] = farthest  # record the i-th farthest point
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)  # xyz coordinates of the current farthest point
        dist = torch.sum((xyz - centroid) ** 2, -1)      # squared distance to that centroid; -1 sums over the coordinate dimension
        mask = dist < distance  # boolean tensor
        distance[mask] = dist[mask]  # where True, update the recorded minimum distance to the selected set
        farthest = torch.max(distance, -1)[1]  # torch.max returns (values, indices); take the index of the point farthest from the set
    return centroids
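# Illustrative usage (not part of the original file):
#   xyz = torch.rand(2, 1024, 3)                # batch of 2 clouds, 1024 points each
#   idx = farthest_point_sample(xyz, 128)       # idx.shape == [2, 128]
#   centers = index_points(xyz, idx)            # centers.shape == [2, 128, 3]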


def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius                      # take up to nsample points within radius of each center in new_xyz
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]            # all points in the cloud
        new_xyz: query points, [B, S, 3]      # the S centroids produced by farthest_point_sample; new_xyz holds their xyz
    Return:
        group_idx: grouped points index, [B, S, nsample]      # indices of the nsample grouped points
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])  # arange builds the indices, view makes them 3-D, repeat copies them to [B, S, N]
    sqrdists = square_distance(new_xyz, xyz)          # squared Euclidean distance between every center and every point
    group_idx[sqrdists > radius ** 2] = N       # points outside the radius get the sentinel index N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # sort ascending; entries set to N are the largest, so the first nsample entries are the in-radius points. sort returns (values, indices); [0] takes the values
    # Some of the first nsample entries may still be N (the ball contains fewer than nsample points); such entries are discarded and replaced by the first (closest) point
    # group_first: [B, S, nsample], the first index of every group copied out to shape [B, S, nsample] so it can be used for the replacement
    # view is needed because group_idx[:, :, 0] is a 2-D tensor and has to be made 3-D again
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    # boolean mask of shape [B, S, nsample] marking the entries that are still N
    mask = group_idx == N
    # replace those entries with the first point of the group
    group_idx[mask] = group_first[mask]
    return group_idx
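# Illustrative usage (not part of the original file), continuing the example above:
#   group_idx = query_ball_point(0.2, 32, xyz, centers)   # [2, 128, 32] neighbor indices
#   grouped = index_points(xyz, group_idx)                # [2, 128, 32, 3] neighbor coordinates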


def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: number of centroids to sample
        radius: ball query radius
        nsample: max number of points gathered per ball
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, nsample, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] centroid indices
    new_xyz = index_points(xyz, fps_idx)    # centroid coordinates, [B, npoint, C]
    idx = query_ball_point(radius, nsample, xyz, new_xyz)   # ball query: indices of each centroid's neighbors
    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C]    # coordinates of the grouped neighbors
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)    # coordinates relative to each centroid

    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D]; C=3, D is the per-point feature dimension (e.g. normals, color)
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points


def sample_and_group_all(xyz, points):
    """
    Input:
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, 1, 3]
        new_points: sampled points data, [B, 1, N, 3+D]
    """
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)       # new_xyz is the single centroid, represented by the origin
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points


class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))   # the shared MLP is implemented as 1x1 convolutions
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        N is the number of input points, C is the coordinate dimension (C=3), D is the feature dimension (everything beyond the coordinates)
        S is the number of output points, C is the coordinate dimension, D' is the new feature dimension
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)  # [B, N, 3]
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample, npoint]    # PyTorch uses the NCHW channel order
        # N - Batch
        # C - Channel
        # H - Height
        # W - Width
        # the 1x1 convolution over the [C+D, nsample, npoint] tensor acts point-wise, which is equivalent to a 1-D convolution over each (C+D)-dimensional feature
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points =  F.relu(bn(conv(new_points)))

        new_points = torch.max(new_points, 2)[0]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points


class PointNetSetAbstractionMsg(nn.Module):    
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        self.npoint = npoint
        self.radius_list = radius_list
        self.nsample_list = nsample_list
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3
            for out_channel in mlp_list[i]:
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []
        # gather points for each (radius, nsample) setting: multi-scale grouping
        for i, radius in enumerate(self.radius_list):
            K = self.nsample_list[i]
            group_idx = query_ball_point(radius, K, xyz, new_xyz)
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C)
            if points is not None:
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S]
            for j in range(len(self.conv_blocks[i])):
                conv = self.conv_blocks[i][j]
                bn = self.bn_blocks[i][j]
                grouped_points =  F.relu(bn(conv(grouped_points)))
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S]
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)
        new_points_concat = torch.cat(new_points_list, dim=1)
        return new_xyz, new_points_concat


class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]   # all (dense) points
            xyz2: sampled input points position data, [B, C, S]   # the sampled (sparse) points
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        "  将B C N 转换为B N C 然后利用插值将高维点云数目S 插值到低维点云数目N (N大于S)"
        "  xyz1 低维点云  数量为N   xyz2 高维点云  数量为S"
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)
        
        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape
        
        "如果最后只有一个点,就将S直复制N份后与与低维信息进行拼接"
        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2) # [B,N,S]
            dists, idx = dists.sort(dim=-1)    # sort to find the nearest sampled neighbors
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3]: distances and indices of the 3 nearest of the S sampled points for each of the N points

            dist_recip = 1.0 / (dists + 1e-8)    # reciprocal of the distances, shape e.g. [2, 512, 3]; corresponds to w_i(x) in the paper
            norm = torch.sum(dist_recip, dim=2, keepdim=True)   # sum over the three nearest neighbors; the denominator of the formula in the paper
            weight = dist_recip / norm
            """
            weight is the interpolation weight: dist_recip holds the reciprocal distances of the three neighbors
            and norm holds their sum, so dividing one by the other gives each neighbor's share of the total.
            """
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)  # weighted sum of the neighbors' features

        if points1 is not None:
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
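  To make the interpolation weights concrete, here is a small hand calculation with made-up squared distances:

import torch

dists = torch.tensor([[[1.0, 2.0, 4.0]]])                   # [B=1, N=1, 3] distances to the 3 nearest sampled points
dist_recip = 1.0 / (dists + 1e-8)                           # [1.0, 0.5, 0.25]
weight = dist_recip / dist_recip.sum(dim=2, keepdim=True)
print(weight)   # tensor([[[0.5714, 0.2857, 0.1429]]]) -- closer neighbors get larger weights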

  Overall, this file implements two main building blocks: PointNetSetAbstraction and PointNetFeaturePropagation.
PointNetSetAbstractionMsg is essentially PointNetSetAbstraction applied with several grouping radii at once, with the features from each radius concatenated (multi-scale grouping).

  The comments in this article draw on the blogger weixin_42707080's PointNet++ series and on the article "PointNet++ Upsampling (Feature Propagation)".
