通俗易懂的解释numpy中的广播

广播是numpy对不同shape的array进行数值计算的方式,符合一定规则的前提下,将较小的array“广播”成更大的、可以计算的array。广播意味着一种向量化操作,从而在类似C语言中产生大量循环,这会导致内存和计算效率的低效。在Python中,广播不会做大量的数据复制并且通常使计算更加高效。

标准的数组计算形式为两个shape形状一样:

>>> a = np.array([1.0, 2.0, 3.0])
>>> b = np.array([2.0, 2.0, 2.0])
>>> a * b
array([ 2.,  4.,  6.])

如果shape不一致,需要满足一定规则,可以使用广播。

广播规则

网上很多文章的写法都是:

  1. 让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐
  2. 输出数组的shape是输入数组shape的各个轴上的最大值
  3. 如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错
  4. 当输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值

我看的时候看了很久都没有看懂。其实可以更简单的描述:
对两个数组,分别比较他们的每一个维度(若其中一个数组没有当前维度则忽略),满足:

  1. 当前维度的值相等
  2. 当前维度的值有一个是1

若条件不满足,抛出“ValueError: frames are not aligned”异常。
输出数组的维度是每一个维度的最大值,广播将值为1的维度进行“复制”、“拉伸”,如图所示


需要注意的是这里的“复制”只是一个抽象概念,Python并不会对数据进行真实复制。

举例:

Image  (3d array): 256 x 256 x 3
Scale  (1d array):             3
Result (3d array): 256 x 256 x 3

在第三个维度相等

A      (4d array):  8 x 1 x 6 x 1
B      (3d array):      7 x 1 x 5
Result (4d array):  8 x 7 x 6 x 5

在第2、3、4维度满足第2个条件

A      (2d array):      2 x 1
B      (3d array):  8 x 4 x 3

错误的情况,在第二个维度不满足条件。

参考资料
numpy文档 :https://docs.scipy.org/doc/numpy-1.13.0/user/basics.broadcasting.html
图形化描述:http://scipy.github.io/old-wiki/pages/EricsBroadcastingDoc

猜你喜欢

转载自blog.csdn.net/xiang_freedom/article/details/77968164