在python中,Numpy与Pytorch都支持广播。
广播的原则:如果两个数组的后缘维度(trailing dimension,即从末尾开始算起的维度)的轴长度相符,或其中的一方的长度为1,则认为它们是广播兼容的。广播会在缺失和(或)长度为1的维度上进行。
说的有些复杂,举几个例子:
如果两个数组 a 和 b 形状相同,即满足 a.shape == b.shape,那么 a*b 的结果就是 a 与 b数组对应位相乘。这要求维数相同,且各维度的长度相同。
import numpy as np
a = np.array([1, 2, 3, 4])
b = np.array([10, 20, 30, 40])
c = a * b
print(c)
>>
[ 10 40 90 160]
arr1的shape为(4,3),arr2的shape为(3,)。arr2的shape类似于(1,3),所以说它们的后缘维度相等,arr1的第二维长度为3,和arr2的维度相同。arr1和arr2的shape并不一样,但是它们可以执行相加操作,这就是通过广播完成的,在这个例子当中是将arr2沿着0轴进行扩展。
import numpy as np
arr1 = np.array([[0, 0, 0],[1, 1, 1],[2, 2, 2], [3, 3, 3]])#(4,3)
arr2 = np.array([1, 2, 3]) #相当于(1,3)
arr_sum = arr1 + arr2
print(arr_sum)
>>
[[1 2 3]
[2 3 4]
[3 4 5]
[4 5 6]]
图例:
再来一组图例:
显然,是将两个数组按最后的维度对齐,然后判定长度是否相等或至少有一个为1,如果哪个维度没有长度就看成1。从后往前,每一个维度都这样做,如果有前面说的例外的情况就判定为不能广播。
例如:
(3,4,2)+(4,2),从后往前,2等于2,4等于4,接下来是没有的长度看成1,3与1至少有一个1,所以可以广播,(4,2)广播的方式就是3*(4,2)
(4,3)+(4,1),从后往前,1与3至少有一个1,4等于4,所以可以广播。(4,1)的广播方式是(4,1)*3
(4,3)+(4,),(4,)可以看作(1,4)。从后往前,4不等于3,所以不能广播,写这样的代码执行后会报错。
再来看一个pytorch的例子:
import torch
x = torch.tensor([[[[1,2],[1,2],[1,2]]]])#1 1 3 2
y=torch.tensor([[-2],[4],[-8]])#3 1
print(x.shape)
print(y.shape)
print(x+y)
>>
torch.Size([1, 1, 3, 2])
torch.Size([3, 1])
tensor([[[[-1, 0],
[ 5, 6],
[-7, -6]]]])
(1,1,3,2)+(3,1),从后往前,1与2至少有一个1,3等于3,接下来是没有的长度看成1,1与1至少有一个1,再接下来是没有的长度看成1,1与1至少有一个1,所以可以广播。
【有问题请直说】