高阶 numpy 数组快速插值(高阶快插)算法探讨

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: https://blog.csdn.net/xufive/article/details/94718525

在科学计算和数据处理领域,数据插值是我们经常面对的问题。尽管 numpy 自身提供了 numpy.interp 插值函数,但只能做一维线性插值,因此,在实际工作中,我们更多地使用 scipy 的 interpolate 子模块。关于 numpy 和 scipy 的关系,有兴趣的话,可以参考拙作《数学建模三剑客MSN》

遗憾的是,scipy.interpolate 只提供了一维和二维的插值算法,而大名鼎鼎的商业软件 Matlab 则有三维插值函数可用。事实上,三维乃至更高阶的插值需求还是挺普遍的。比如,三维体数据绘制时,为了增强显示效果,让数据体看起来更细腻,三维插值是必不可少的。下图是三维数据插值前后的3D显示效果对比(使用 pyopengl 绘制)。
在这里插入图片描述三维乃至更高阶的数据插值,简称高阶快插,通常都是线性插值。最自然的想法,是循环调用 scipy.interpolate 的 interp1d 或 interp2d 函数实现三维插值,但是循环调用的效率低得无法忍受,完全体现不出 numpy 的广播和矢量化的特点。

网上有人提出在_fitpack模块中使用_spline属性和低级别的_bspleval()函数,以实现高阶快插。详情见《python – 3D阵列上的快速插值》。文中的代码我费了九牛二虎之力仍然没有跑通。

那么,能否借助 numpy 的广播和矢量化的特点,实现高阶快插呢?经过验证,这个思路是完全可行的。我们先从一维线型插值开始讨论。

  1. 使用元素重复函数 repeat() 将长度为 n 的一维数组扩展成长度为 2n-1 的数组,排在偶数位置的元素是前面元素的重复;
  2. 所有偶数位置的元素减去前后元素差的一半。

我们来验证一下:

import numpy as np
>>> a = np.arange(5, dtype=np.float)
>>> a
array([0., 1., 2., 3., 4.])
>>> a = a.repeat(2)[:-1]
>>> a
array([0., 0., 1., 1., 2., 2., 3., 3., 4.])
>>> a[1::2] += (a[2::2]-a[1::2])/2
>>> a
array([0. , 0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. ])

结果符合预期,完全矢量化!进一步验证,还可以证明该方法对 numpy.nan 类型的数据处理方式符合我们的期望。

>>> a = np.arange(5, dtype=np.float)
>>> a[2] = np.nan
>>> a
array([ 0.,  1., nan,  3.,  4.])
>>> a = a.repeat(2)[:-1]
>>> a
array([ 0.,  0.,  1.,  1., nan, nan,  3.,  3.,  4.])
>>> a[1::2] += (a[2::2]-a[1::2])/2
>>> a
array([0. , 0.5, 1. , nan, nan, nan, 3. , 3.5, 4. ])

接下来,我们就可以尝试用这个思路实现三维数组的线性插值了。代码如下:

import numpy as np

def interp3d(arr_3d):
    """三维数组线性插值
    
    arr_3d      - numpyp.ndarray类型的三维数组
    """
    
    layers, rows, cols = arr_3d.shape
    
    arr_3d = arr_3d.repeat(2).reshape((layers*rows, -1))
    arr_3d = arr_3d.repeat(2, axis=0).reshape((layers, -1))
    arr_3d = arr_3d.repeat(2, axis=0).reshape((layers*2, rows*2, cols*2))[:-1, :-1, :-1]
    
    
    arr_3d[:,:,1::2] += (arr_3d[:,:,2::2]-arr_3d[:,:,1::2])/2
    arr_3d[:,1::2,:] += (arr_3d[:,2::2,:]-arr_3d[:,1::2,:])/2
    arr_3d[1::2,:,:] += (arr_3d[2::2,:,:]-arr_3d[1::2,:,:])/2
    
    return arr_3d

if __name__ == '__main__':
    import time
    
    arr_3d = np.random.randn(100, 200, 300)
    print(u'插值前数组的shape:', arr_3d.shape)
    
    t0 = time.time()
    arr_3d = interp3d(arr_3d)
    t1 = time.time()
    
    print(u'插值后数组的shape:', arr_3d.shape)
    print(u'耗时%.03f秒'%(t1-t0,))

运行结果如下:

PS D:\XufiveGit\interp3d> py -3 .\test.py
插值前数组的shape: (100, 200, 300)
插值后数组的shape: (199, 399, 599)
耗时1.281秒

在我的认知范围内,这应该是目前最快的高阶快插算法了。虽殚精竭虑而后得,亦弗敢专也,必以分享于同好。

猜你喜欢

转载自blog.csdn.net/xufive/article/details/94718525