pytorch & numpy广播法则

广播法则

  1. 所有数组向维度最高的数组看齐,若维度不足则在最前面的维度用1补齐
  2. 扩展维度后,所有数组在某一维度相同或者长度为1,否则不能计算
  3. 当可以计算时,将长度为1的维度扩展为另一数组相应维度的长度
a = torch.ones(3, 2)
b = torch.zeros(2,3,1)
a + b
# a : (3, 2)-->(1, 3, 2)
# a : (1, 3, 2)-->(2, 3, 2)
# b : (2, 3, 1)-->(2, 3, 2)
# a + b : (2, 3, 2) 

手工实现广播(建议,较为直观):

a.view(1, 3, 2).expand(2, 3, 2)
b.expand(2, 3, 2)
# repeat和expand功能类似,但是repeat会把数据复制多份,会占用额外空间

猜你喜欢

转载自www.cnblogs.com/lzping/p/12363643.html
今日推荐