PyTorch & NumPy broadcasting rules

Broadcasting rules

  1. All arrays are aligned with the array that has the most dimensions; an array with fewer dimensions has its shape padded with 1s at the front.
  2. After padding, in every dimension the arrays must either have the same length or one of them must have length 1; otherwise they cannot be computed together.
  3. When an array has length 1 in some dimension, it is expanded along that dimension to the length of the corresponding dimension of the other array.
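The three rules can be sketched as a small pure-Python function that computes the result shape of broadcasting two shapes. This is a minimal illustration only; the helper name `broadcast_shape` is hypothetical. In practice NumPy (`np.broadcast_shapes`) and PyTorch (`torch.broadcast_shapes`) already provide this computation.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape of two shapes, or raise ValueError.

    Hypothetical helper that follows the three rules above.
    """
    # Rule 1: left-pad the shorter shape with 1s so both have the same number of dims
    ndim = max(len(shape_a), len(shape_b))
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)

    result = []
    for i, (da, db) in enumerate(zip(a, b)):
        # Rule 2: lengths must match, or one of them must be 1
        if da != db and da != 1 and db != 1:
            raise ValueError(
                f"shapes {tuple(shape_a)} and {tuple(shape_b)} are not broadcastable "
                f"(dim {i}: {da} vs {db})"
            )
        # Rule 3: a length-1 dimension is stretched to the other array's length
        result.append(db if da == 1 else da)
    return tuple(result)


print(broadcast_shape((3, 2), (2, 3, 1)))  # (2, 3, 2)
```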
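import torch

a = torch.ones(3, 2)
b = torch.zeros(2, 3, 1)
a + b
# a : (3, 2)-->(1, 3, 2)     rule 1: prepend a length-1 dim
# a : (1, 3, 2)-->(2, 3, 2)  rule 3: stretch dim 0 to length 2
# b : (2, 3, 1)-->(2, 3, 2)  rule 3: stretch dim 2 to length 2
# a + b : (2, 3, 2)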
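As the title says, NumPy follows the same rules. The sketch below repeats the example with NumPy arrays and also shows a pair of shapes that violates rule 2 (trailing dimensions 2 vs 3, neither of them 1), which NumPy rejects with a `ValueError`.

```python
import numpy as np

a = np.ones((3, 2))
b = np.zeros((2, 3, 1))
c = a + b
print(c.shape)  # (2, 3, 2)

# Rule 2 violated: last dims are 2 and 3, and neither is 1
try:
    np.ones((3, 2)) + np.ones((2, 3))
except ValueError as err:
    print("not broadcastable:", err)
```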

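Broadcasting manually (recommended, more intuitive):

a.view(1, 3, 2).expand(2, 3, 2)  # a: (3, 2) --> (1, 3, 2) --> (2, 3, 2)
b.expand(2, 3, 2)                # b: (2, 3, 1) --> (2, 3, 2); adding the two gives the same result as a + b
# repeat works like expand, but repeat copies the data multiple times and takes up extra memory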
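The difference between `expand` and `repeat` can be checked directly. In this minimal sketch, `expand` returns a view whose new leading dimension has stride 0 and which shares the memory of `a`, while `repeat` allocates a new contiguous tensor that holds real copies. Both produce the same values.

```python
import torch

a = torch.ones(3, 2)

# expand: a view; the expanded dim gets stride 0, so no data is copied
expanded = a.view(1, 3, 2).expand(2, 3, 2)
# repeat: a new tensor holding 2 real copies of the data
repeated = a.view(1, 3, 2).repeat(2, 1, 1)

print(torch.equal(expanded, repeated))       # True: same values
print(expanded.stride(), repeated.stride())  # (0, 2, 1) (6, 2, 1)
print(expanded.data_ptr() == a.data_ptr())   # True: shares a's memory
print(repeated.data_ptr() == a.data_ptr())   # False: separate buffer
```

Note that `repeat` takes the number of copies per dimension, while `expand` takes the target size (with -1 meaning "keep this dimension's size").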


Source: www.cnblogs.com/lzping/p/12363643.html