Numpy中的广播(Broadcasting)

Numpy的Universal functions 中要求输入的数组shape是一致的,当数组的shape不想等的时候,则会使用广播机制,调整数组使得shape一样,满足规则,则可以运算,否则就出错 
四条规则如下:

All input arrays with ndim smaller than the input array of largest ndim, have 1’s prepended to their shapes.
The size in each dimension of the output shape is the maximum of all the input sizes in that dimension.
An input can be used in the calculation if its size in a particular dimension either matches the output size in that dimension, or has value exactly 1.
If an input has a dimension size of 1 in its shape, the first data entry in that dimension will be used for all calculations along that dimension. In other words, the stepping machinery of the ufunc will simply not step along that dimension (the stride will be 0 for that dimension).
中文

让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐
输出数组的shape是输入数组shape的各个轴上的最大值
如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错
当输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值
以下通过实例来说明这些问题

一般情况下,numpy 都是采用一一对应的方式(element-by-element )进行计算

例子1:

>>> from numpy import array 
>>> a = array([1.0,2.0,3.0])
>>> b = array([2.0,2.0,2.0])
>>> a * b
array([ 2.,  4.,  6.])
1
2
3
4
5
当不相等时,则会采用规则对其:

>>> from numpy import array
>>> a = array([1.0,2.0,3.0])
>>> b = 2.0
>>> a * b
array([ 2.,  4.,  6.])
1
2
3
4
5
a.shape得到的是(3,) b是一个浮点数,如果转换成array,则b.shape是一个(),a的1轴对齐,补齐为1,a.shape(3,1),b对齐,则对齐也为(3,1),然后按照一一对应的方式计算

或许上述例子不是太明确,下面采用一个更加确切的例子说明:

>>> import numpy as np
>>> a = np.arange(0, 6).reshape(6, 1)
>>> a
array([[ 0], [1], [2], [3], [4], [5]])
>>> a.shape
(6, 1)
>>> b = np.arange(0, 5)
>>> b.shape
(5,)
>>> c = a + b
>>> print c
[[0 1 2 3 4]
 [1 2 3 4 5]
 [2 3 4 5 6]
 [3 4 5 6 7]
 [4 5 6 7 8]
 [5 6 7 8 9]]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
在上述实例中,当使用+运算时,由于shape不一致,按照规则1,会对b进行 
reshape,b.reshape成(1,5),可能会问为什么不是(5,1),因为这个就不能计算了,那么如果b的shape是(6,)的时候呢,都可以运算;所以对于对齐本身我自己是没有理解太过透彻,所以,我找了一下官方文档,其中的一个图是这样的: 
 
我理解是,其本身的形状是不能改变的,只能在原来的基础上延伸,像上述的例子中,如果b的shape是(6,),如果在broadcasting的时候reshape(6,1)则已经是属于改变了原来的数组的形状,进行了翻转,而不是延伸。

接着上述实例,对于b则reshape成了(1,5),a则保持(6,1),按照规则2,则输出为每个轴上的最大值,则c.shape为(6,5);

对于规则3和规则4,都是在描述延伸的条件和方式,所以对于我的理解我也更加确信了,如果有大侠觉得有问题,请帮忙指正

参考:

NumPy-快速处理数据(http://old.sebug.net/paper/books/scipydoc/numpy_intro.html) 
Array Broadcasting in numpy(http://scipy.github.io/old-wiki/pages/EricsBroadcastingDoc) 
NumPy Reference(http://docs.scipy.org/doc/numpy-1.10.0/reference/ufuncs.html)
--------------------- 

参考原文:https://blog.csdn.net/yangnanhai93/article/details/50127747 
 

猜你喜欢

转载自blog.csdn.net/mdjxy63/article/details/85040094
今日推荐