PixelShuffle详解和cuda实现

1.背景

1.1 PixelShuffle的出处

PixelShuffle这一操作出自论文Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network, 论文中称为periodic shuffling operator。Pytorch将其实现为

torch.nn.PixleShuffle(upscale_factor)

1.2为什么要用cuda实现

  • 在将包含PixelShuffle的模型向trt转换时,如果特征图很大的时候,例如 ( 1 , 64 ∗ 4 , 1088 , 1920 ) (1, 64*4, 1088, 1920) (1,644,1088,1920) ( 1 , 64 , 1088 ∗ 2 , 1920 ∗ 2 ) (1, 64, 1088*2, 1920*2) (1,64,10882,19202)转换时,会出现out of memory(显卡的显存是32G)。因此,只能单独拿出来用cuda实现;
  • 用cuda实现了发现,会比pytorch的要快一点;

2.什么是PixelShuffle

首先贴一下论文中的图:
在这里插入图片描述
图中彩色部分从 r 2 r^2 r2channels->High-resolution image的示意过程即为PixelShuffle;
PixelShuffle可以看成一个特殊的reshape操作,其通过从通道维度向长宽维度搬移像素,实现上采样,因此可用于SR等需要将特征图放大的task;

  • input shape
    ( N , C ∗ r ∗ r , H , W ) (N, C*r*r, H, W) (N,Crr,H,W)
  • output shape
    ( N , C , r ∗ H , r ∗ W ) (N, C, r*H, r*W) (N,C,rH,rW)
    其中r是想要放大的倍数;

3. PixelShuffle cuda实现

  • 首先看一下论文中对此操作的公式定义:
    在这里插入图片描述
    其中y和x是结果在rH,rW维度上的坐标,对应的是原特征图上的 y%r 和 x%r,这很好理解;
    但是这个公示的c的表达式,只有在C=1的时候是正确的吧?我是没太看懂,希望有看懂的同学指点一下;
  • 由于照着论文公式实现有问题,我就自己总结了公式如下:
    p s ( T ) h , w , c = T ⌊ h / r ⌋ , ⌊ w / r ⌋ , c ∗ r 2 + r ∗ ⌊ h / r ⌋ + ⌊ w / r ⌋ ps(T)_{h, w, c} = T_{ {\lfloor h/r \rfloor}, {\lfloor w/r \rfloor}, c*r^2+r*{\lfloor h/r \rfloor}+{\lfloor w/r \rfloor}} ps(T)h,w,c=Th/r,w/r,cr2+rh/r+w/r
  • cuda代码如下:
__global__ void pixel_shuffle_kernel(const half *x, half *z, int r, int w, int h, int c, int input_c_stride, int input_h_stride, int output_c_stride, int output_h_stride)
{
    const int w_i = blockIdx.x * blockDim.x + threadIdx.x;
	const int h_i = blockIdx.y * blockDim.y + threadIdx.y;
	const int c_i = blockIdx.z * blockDim.z + threadIdx.z;
    const bool withinXbounds = w_i < w;
	const bool withinYbounds = h_i < h;
    const bool withinCbounds = c_i < c;
    if(withinXbounds && withinYbounds && withinCbounds){
        long ic = r*(h_i%r) + (w_i%r) + c_i*r*r;
        long iw = w_i/r;
        long ih = h_i/r;
        long index = 0+ic*(long)input_c_stride+ih*(long)input_h_stride+(long)iw;
        z[0+c_i*output_c_stride+h_i*output_h_stride+w_i] = x[index];
    }
}

经验证,此kernel的结果和torch.nn.PixleShuffle结果一致。

猜你喜欢

转载自blog.csdn.net/BigerBang/article/details/108551305
今日推荐