numpy 广播

numpy的universal function处理数据时,要求输入数组的shape必须一致,当数组的shape不一致时,则会产生广播机制;

广播机制会调整shape,使数组运算满足规则。

广播机制在调整ndarray时的四条规则:

 1 让所有输入数组都向其中shape最长的ndarray看齐,shape中不足的部分都通过在前面添加1补齐

 2 输出数组的shape是输入数组shape的各个轴上的最大值

 3 如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错

 4 当输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值

实例一

import numpy as np
a = np.array([[10,10,10],[20,20,20],[30,30,30]])

print(a.shape) #(3, 3)
print(a)
# [[10 10 10]
#  [20 20 20]
#  [30 30 30]]

b = np.array([1,2,3])
print(b.shape) #(3,)
print(b)
# [1 2 3]

c = a + b
print(c.shape) #(3, 3)
print(c)
# [[11 12 13]
#  [21 22 23]
#  [31 32 33]]

解析

1 这里最长的是a,shape=(3,3),b的shape为1行3列 (3,),对于a而言,b 的行是不足的,因此补足后为

b.shape = 1,3
print(b.shape) #(1, 3)
print(b) #[[1 2 3]]

2 输出数组的shape为输入数组shape在各轴上的最大值,也即(3,3)

 
 

猜你喜欢

转载自www.cnblogs.com/gengyi/p/9206542.html
今日推荐