插帧中grid_sample函数详解

从之前VSR到后来做MEMC,基本都要用到该函数,但是VSR后期后很多工作很多抛弃了warp操作,因此没有深入研究。但是MEMC是必须用的,否则就要用超级大的网络直接端到端的生成。认准原创https://blog.csdn.net/longshaonihaoa/article/details/125964061

MEMC系列文章:
运动估计运动补偿(Motion estimation and motion compensation,MEMC)入门总结
深度学习MEMC插帧论文列表paper list
光流估计中cost volume详解
插帧中grid_sample函数详解

1、grid_sample基本功能讲解

官方讲解
https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

函数原型

torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=None)

参数选择
函数有两个输入项,三个可选参数项。
input:输入,原始图像。维度[B,3,H,W]
grid:映射表。维度[B,H,W,2],值归一化为[-1, 1]
mode: 插值模式,可选双线性‘bilinear’,最近邻‘nearest’。
padding_mode: 补边模式,可选反射‘reflection’,边缘‘border’,零‘zero’。
align_corners: 对齐模式,是否选择对齐。

函数功能:
首先我们区分一下坐标的区别。比如一张图片,坐标是指某个位置,如(2,3)就是指定图像的第2行第3列那个位置。是说这个位置上的像素值。
对应到grid上,他每个坐标处会有两个值,对应的是映射后的坐标。所以grid的最后一维是2,分别对应X,Y。这里XY的值归一化到了[-1,1],在应用是需注意,在函数内部实现中会映射到原始尺寸。下面例子中为了形象讲grid时用非归一化的值。(为啥要归一化,开始我觉得蛮多此一举,最近我看图形学也有类似的归一化,应该有一样的原理?)当对输入图像进行处理时,比如需要处理(2,3)这个坐标。那就查grid中坐标为(2,3)的值,假设为(3,3),那就把原图中(2,3)这个坐标上的值 赋给 输出(3,3)这个坐标。

参数介绍:
padding_mode:当grid的值超出了宽高界限,该怎么选择值。
reflection: 用关于边界的对称点的值,直到坐标落在界内。
border:用边界的值代替
zeros:用0代替。

align_corner: 双线性插值的固有参数,是否对其。
这两个参数在下文代码中会更详细介绍。

2、ATen代码实现

完整的代码可参考官方实现

基本逻辑如下:

# 逐像素循环处理
for (const auto h : c10::irange(out_H)) {
    for (const auto w : c10::irange(out_W)) {
    	...
    	// 对坐标进行处理,接下来会讲这个函数
    	scalar_t ix = grid_sampler_compute_source_index(x, inp_W, padding_mode, align_corners);
        scalar_t iy = grid_sampler_compute_source_index(y, inp_H, padding_mode, align_corners);
        if (interpolation_mode == GridSamplerInterpolation::Bilinear) {
            // 双线性插值操作
            ... 
            }
        else if (interpolation_mode == GridSamplerInterpolation::Nearest) {
            // 最近邻插值操作
            int64_t ix_nearest = static_cast<int64_t>(std::nearbyint(ix));
            int64_t iy_nearest = static_cast<int64_t>(std::nearbyint(iy));
            ...
            }

其实代码中对函数理解最重要的就是grid_sampler_compute_source_index 函数,其代码可见官方地址

扫描二维码关注公众号,回复: 14652865 查看本文章

从以下可以看出它调用了两个函数,一个是unnormalize,一个是计算坐标。

scalar_t grid_sampler_compute_source_index(...) {
    
    
  coord = grid_sampler_unnormalize(coord, size, align_corners);
  coord = compute_coordinates(coord, size, padding_mode, align_corners);
  return coord;
}

unnormalize 实现如下。根据align_corner的设置得到不同运算。当align_corner为True时,原来的[-1,1]映射为[0, size - 1]。False则将[-1, 1] to [-0.5, size - 0.5]。具体代码如下

scalar_t grid_sampler_unnormalize(scalar_t coord, int size, bool align_corners) {
  if (align_corners) {
    // unnormalize coord from [-1, 1] to [0, size - 1]
    return ((coord + 1.f) / 2) * (size - 1);
  } else {
    // unnormalize coord from [-1, 1] to [-0.5, size - 0.5]
    return ((coord + 1.f) * size - 1) / 2;
  }
}

注意align_corner并非只有此处使用。
计算坐标主要是说对padding_mode的处理。主要可以看以下这部分代码:

scalar_t reflect_coordinates(scalar_t in, int twice_low, int twice_high) {
  ...
  scalar_t min = static_cast<scalar_t>(twice_low) / 2;
  scalar_t span = static_cast<scalar_t>(twice_high - twice_low) / 2;
  in = ::fabs(in - min);
  scalar_t extra = ::fmod(in, span);
  int flips = static_cast<int>(::floor(in / span));
  if (flips % 2 == 0) {    // return略有修改,因为我觉得这样更清楚
    return min + extra;
  } else {
    return min + (span - extra);
  }
}

3、CUDA实现

cuda官方实现的核函数在这里
感觉cuda比上面写的更清楚,区别在于没有循环。因为cuda核函数是对某一个位置进行操作的。

4、注意点

grid给定的事归一化的坐标值,而非偏移量。区别在于,坐标值直接通过unnormalize得到目标坐标。而偏移量需要加上当前坐标才能的到目标坐标。

猜你喜欢

转载自blog.csdn.net/longshaonihaoa/article/details/125964061
今日推荐