Pytorch中的广播机制(Broadcast)

1. 广播机制定义

如果一个PyTorch操作支持广播,则其Tensor参数可以自动扩展为相等大小(不需要复制数据)。通常情况下,小一点的数组会被 broadcast 到大一点的,这样才能保持大小一致。

2. 广播机制规则

2.1 如果遵守以下规则,则两个tensor是“可广播的”:

  • 每个tensor至少有一个维度
  • 遍历tensor所有维度时,从末尾开始遍历(从右往左开始遍历)(从后往前开始遍历),两个tensor存在下列情况:
    • tensor维度相等
    • tensor维度不等且其中一个维度为1
    • tensor维度不等且其中一个维度不存在

2.2 如果两个tensor是“可广播的”,则计算过程遵循下列规则:

  • 如果两个tensor的维度不同,则在维度较小的tensor的前面增加维度,使它们维度相等
  • 对于每个维度,计算结果的维度值取两个tensor中较大的那个值
  • 两个tensor扩展维度的过程是将数值进行复制

3.代码举例

3.1 相同维度,一定可以 broadcasting。

# 相同维度,一定可以 broadcasting
x=torch.ones(5,7,3)
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape
输出结果如下:
(torch.Size([5, 7, 3]), torch.Size([5, 7, 3]), torch.Size([5, 7, 3]))

3.2 x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting。

# x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting
x=torch.ones((0,))
y=torch.ones(5,7,3)
z = x+y
x.shape,y.shape,z.shape

x,y不能进行广播
3.3 x 和 y 可以广播。

# x 和 y 可以广播
x=torch.ones(5,3,4,1)
y=torch.ones(  3,1,1)
z = x+y
x.shape,y.shape,z.shape
# 从尾部维度开始遍历
# 1st尾部维度: x和y相同,都为1。
# 2nd尾部维度: y为1,x为4,符合维度不等且其中一个维度为1,则广播为4。
# 3rd尾部维度: x和y相同,都为3。
# 4th尾部维度: y维度不存在,x为5,符合维度不等且其中一个维度不存在,则广播为5。
输出结果如下:
(torch.Size([5, 3, 4, 1]), torch.Size([3, 1, 1]), torch.Size([5, 3, 4, 1]))

3.4 x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。

# x 和 y 不可以广播,因为倒数第三维度x为2,y为3,不符合维度不等且其中一个维度为1。
x=torch.ones(5,2,4,1)
y=torch.ones(  3,1,1)
z = x+y
x.shape,y.shape,z.shape

x,y不能进行广播
3.5 x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等,同时使他们维度大小相同。

# x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等。
x=torch.ones(5,2,4,1)
y=torch.ones(1,1)
z = x+y
x.shape,y.shape,z.shape
输出结果如下:
(torch.Size([5, 2, 4, 1]), torch.Size([1, 1]), torch.Size([5, 2, 4, 1]))

4. in - place 语义

in-place operation称为原地操作符,在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch中经常加后缀“”来代表原地操作符,例:.add _()、.scatter(),in-place操作不允许tensor使用广播机制那样来改变张量形状维度大小,如下例子所示。

# x 和 y 不可以广播
x=torch.empty(1,3,1)
y=torch.empty(3,1,7)
z = x.add_(y)
x.shape,y.shape,z.shape

使用in-place原地操作符

猜你喜欢

转载自blog.csdn.net/flyingluohaipeng/article/details/125109094