[Code Reading] The CUDA Implementation of FPS in PointNet++

I had only ever used PointNet++ and never thought about how it is actually implemented. Having recently learned a bit of CUDA programming, I'll walk through a concrete example in detail here.

The code used in this post is the PointNet++ implementation from PointRCNN.

The PyTorch Interface

FPS is implemented in C++ and CUDA, so let's first look at how it is exposed to PyTorch, in pointnet2/pointnet2_utils.py:

class FurthestPointSampling(Function):
    @staticmethod
    def forward(ctx, xyz: torch.Tensor, npoint: int) -> torch.Tensor:
        """
        Uses iterative furthest point sampling to select a set of npoint features that have the largest
        minimum distance
        :param ctx:
        :param xyz: (B, N, 3) where N > npoint
        :param npoint: int, number of features in the sampled set
        :return:
             output: (B, npoint) tensor containing the set
        """
        assert xyz.is_contiguous()

        B, N, _ = xyz.size()
        output = torch.cuda.IntTensor(B, npoint)
        temp = torch.cuda.FloatTensor(B, N).fill_(1e10)

        pointnet2.furthest_point_sampling_wrapper(B, N, npoint, xyz, temp, output)
        return output

    @staticmethod
    def backward(xyz, a=None):
        return None, None


furthest_point_sample = FurthestPointSampling.apply
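
With this Function defined, sampling is a one-liner: for a contiguous GPU tensor xyz of shape (B, N, 3), idx = furthest_point_sample(xyz, 512) returns a (B, 512) int tensor of point indices, which in this repo is typically passed to gather_operation (also in pointnet2_utils.py) to fetch the coordinates of the sampled points. Note that backward simply returns None for both inputs: sampling only selects indices, so no gradient flows through it.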

The core function is furthest_point_sampling_wrapper, which is implemented in C++. For the details of how the C++ code is compiled and linked, and how it is turned into a PyTorch-compatible function, see my other blog post. A short sketch of the binding follows below.
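
In short, the wrapper is exposed to Python via pybind11. In this repo the binding lives in pointnet2/src/pointnet2_api.cpp; trimmed down to just the function we care about, it looks roughly like this (the full file registers all of the pointnet2 ops the same way):

#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include "sampling_gpu.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // exposed to Python as pointnet2.furthest_point_sampling_wrapper
    m.def("furthest_point_sampling_wrapper", &furthest_point_sampling_wrapper,
          "furthest_point_sampling_wrapper");
}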

cpp

The code is in pointnet2/src/sampling.cpp:

int furthest_point_sampling_wrapper(int b, int n, int m, 
    at::Tensor points_tensor, at::Tensor temp_tensor, at::Tensor idx_tensor) {

    const float *points = points_tensor.data<float>();
    float *temp = temp_tensor.data<float>();
    int *idx = idx_tensor.data<int>();

    cudaStream_t stream = THCState_getCurrentStream(state);
    furthest_point_sampling_kernel_launcher(b, n, m, points, temp, idx, stream);
    return 1;
}

As you can see, the cpp wrapper receives the tensors passed in from Python, extracts raw data pointers from them, and then calls the kernel_launcher function defined in the .cu file. (Here state is a global THCState* declared near the top of the file; tensor.data<T>() and THCState_getCurrentStream are legacy APIs that newer PyTorch versions replace with tensor.data_ptr<T>() and at::cuda::getCurrentCUDAStream().)

cu

The kernel_launcher does little work itself. It first decides how many threads to launch per block:

void furthest_point_sampling_kernel_launcher(int b, int n, int m, 
    const float *dataset, float *temp, int *idxs, cudaStream_t stream) {

    // dataset: (B, N, 3)
    // temp: (B, N)
    // output:
    //      idx: (B, M)

    cudaError_t err;
    unsigned int n_threads = opt_n_threads(n);  // compute the thread count, at most 1024

    switch (n_threads) {
        case 1024:
        // I believe <1024> simply passes the launched thread count into the
        // kernel as a compile-time template argument
        furthest_point_sampling_kernel<1024><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 512:
        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 256:
        furthest_point_sampling_kernel<256><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 128:
        furthest_point_sampling_kernel<128><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 64:
        furthest_point_sampling_kernel<64><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 32:
        furthest_point_sampling_kernel<32><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 16:
        furthest_point_sampling_kernel<16><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 8:
        furthest_point_sampling_kernel<8><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 4:
        furthest_point_sampling_kernel<4><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 2:
        furthest_point_sampling_kernel<2><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        case 1:
        furthest_point_sampling_kernel<1><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs); break;
        default:
        furthest_point_sampling_kernel<512><<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
    }

    err = cudaGetLastError();
    if (cudaSuccess != err) {
        fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
        exit(-1);
    }
}
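
The helper opt_n_threads lives in pointnet2/src/cuda_utils.h in this repo. It picks the largest power of two that does not exceed n, clamped to the range [1, 1024], which is why the switch above only needs power-of-two cases. It looks roughly like this (treat the exact body as approximate):

#include <cmath>
#include <algorithm>

#define TOTAL_THREADS 1024

inline int opt_n_threads(int work_size) {
    // floor(log2(work_size)): exponent of the largest power of two <= work_size
    const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);
    return std::max(std::min(1 << pow_2, TOTAL_THREADS), 1);
}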

Next, let's look at the kernel itself, in pointnet2/src/sampling_gpu.cu:

// block_size corresponds to the <1024> template argument in the kernel_launcher function
template <unsigned int block_size>
__global__ void furthest_point_sampling_kernel(int b, int n, int m, 
    const float *__restrict__ dataset, float *__restrict__ temp, int *__restrict__ idxs) {

    // dataset: (B, N, 3)
    // temp: (B, N)
    // output:
    //      idx: (B, M)

    if (m <= 0) return;
    // Two shared-memory arrays: dists holds the best (largest) distance found by
    // each thread, dists_i the corresponding point index
    __shared__ float dists[block_size];
    __shared__ int dists_i[block_size];

    int batch_index = blockIdx.x;
    // The grid launches one block per batch element, so each block handles one batch.
    // dataset, temp and idxs are raw pointers; offsetting them by batch_index moves
    // each one to the slice this block is responsible for
    dataset += batch_index * n * 3;
    temp += batch_index * n;
    idxs += batch_index * m;

    int tid = threadIdx.x;
    const int stride = block_size;

    int old = 0;
    // FPS always takes point 0 as the first sample; let thread 0 record it
    if (threadIdx.x == 0)
        idxs[0] = old;

    __syncthreads();
    for (int j = 1; j < m; j++) {
        int besti = 0;
        float best = -1;
        // fetch the coordinates of the point selected in the previous iteration
        float x1 = dataset[old * 3 + 0];
        float y1 = dataset[old * 3 + 1];
        float z1 = dataset[old * 3 + 2];
        for (int k = tid; k < n; k += stride) {
            // the threads split the work: each one strides through the array,
            // handling roughly n / block_size points
            float x2, y2, z2;
            x2 = dataset[k * 3 + 0];
            y2 = dataset[k * 3 + 1];
            z2 = dataset[k * 3 + 2];

            float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
            // temp has shape (B, N); it maintains, for every input point, the
            // minimum distance to all points selected so far
            float d2 = min(d, temp[k]);
            temp[k] = d2;
            besti = d2 > best ? k : besti;
            best = d2 > best ? d2 : best;
        }
        dists[tid] = best;
        dists_i[tid] = besti;
        __syncthreads();

        // parallel tree reduction: find the largest entry in dists (and its index)
        if (block_size >= 1024) {
            if (tid < 512) {
                __update(dists, dists_i, tid, tid + 512);
            }
            __syncthreads();
        }
        if (block_size >= 512) {
            if (tid < 256) {
                __update(dists, dists_i, tid, tid + 256);
            }
            __syncthreads();
        }
        if (block_size >= 256) {
            if (tid < 128) {
                __update(dists, dists_i, tid, tid + 128);
            }
            __syncthreads();
        }
        if (block_size >= 128) {
            if (tid < 64) {
                __update(dists, dists_i, tid, tid + 64);
            }
            __syncthreads();
        }
        if (block_size >= 64) {
            if (tid < 32) {
                __update(dists, dists_i, tid, tid + 32);
            }
            __syncthreads();
        }
        if (block_size >= 32) {
            if (tid < 16) {
                __update(dists, dists_i, tid, tid + 16);
            }
            __syncthreads();
        }
        if (block_size >= 16) {
            if (tid < 8) {
                __update(dists, dists_i, tid, tid + 8);
            }
            __syncthreads();
        }
        if (block_size >= 8) {
            if (tid < 4) {
                __update(dists, dists_i, tid, tid + 4);
            }
            __syncthreads();
        }
        if (block_size >= 4) {
            if (tid < 2) {
                __update(dists, dists_i, tid, tid + 2);
            }
            __syncthreads();
        }
        if (block_size >= 2) {
            if (tid < 1) {
                __update(dists, dists_i, tid, tid + 1);
            }
            __syncthreads();
        }
        // after the reduction, dists_i[0] holds the point whose minimum distance to
        // the sample set is largest; it becomes the point selected in this iteration
        old = dists_i[0];
        if (tid == 0)
            idxs[j] = old;
    }
}
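
One piece is still missing from the walkthrough: the __update device helper used in the reduction above. It is defined near the top of sampling_gpu.cu; the version below is how I read it (if your copy differs slightly, the idea is the same). It merges two reduction slots by keeping the larger distance, and its index, in slot idx1:

__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 
    int idx1, int idx2) {
    // keep the larger of the two candidate distances in slot idx1,
    // together with the index of the point that produced it
    const float v1 = dists[idx1], v2 = dists[idx2];
    const int i1 = dists_i[idx1], i2 = dists_i[idx2];
    dists[idx1] = max(v1, v2);
    dists_i[idx1] = v2 > v1 ? i2 : i1;
}

Putting it all together: for each batch element, one block iterates m times; each iteration computes, in parallel, every point's distance to the newest sample, folds it into the running minimum kept in temp, and then reduces over the block to find the point whose minimum distance is largest. The overall cost is O(n * m) per batch element, spread across block_size threads.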

Reposted from blog.csdn.net/wqwqqwqw1231/article/details/107451089