PyTorch | 广播机制(broadcast)

PyTorch | 广播机制(broadcast)

1. 广播机制定义

\qquad 如果一个PyTorch操作支持广播,则其Tensor参数可以自动扩展为相等大小(不需要复制数据)。通常情况下,小一点的数组会被 broadcast 到大一点的,这样才能保持大小一致。

2. 广播机制规则

  • 如果遵守以下规则,则两个tensor是“可广播的”:
    • 每个tensor至少有一个维度;
    • 遍历tensor所有维度时,从末尾随开始遍历,两个tensor存在下列情况:
      • tensor维度相等。
      • tensor维度不等且其中一个维度为1。
      • tensor维度不等且其中一个维度不存在。
  • 如果两个tensor是“可广播的”,则计算过程遵循下列规则:
    • 如果两个tensor的维度不同,则在维度较小的tensor的前面增加维度,使它们维度相等。
    • 对于每个维度,计算结果的维度值取两个tensor中较大的那个值。
    • 两个tensor扩展维度的过程是将数值进行复制。
  • 示例:
    # 相同维度,一定可以 broadcasting
    x=torch.ones(5,7,3)
    y=torch.ones(5,7,3)
    
    在这里插入图片描述
    # x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting
    x=torch.ones((0))
    y=torch.ones(5,7)
    
    在这里插入图片描述
    # x 和 y 可以广播
    x=torch.ones(5,3,4,1)
    y=torch.ones(  3,1,1)
    # 从尾部维度开始遍历
    # 1st尾部维度: x和y相同,都为1。
    # 2nd尾部维度: y为1,x为4,符合维度不等且其中一个维度为1,则广播为4。
    # 3rd尾部维度: x和y相同,都为3。
    # 4th尾部维度: y维度不存在,x为5,符合维度不等且其中一个维度不存在,则广播为5。
    
    在这里插入图片描述
    # x 和 y 不可以广播,因为3rd尾部维度x为2,y为3,不符合维度不等且其中一个维度为1。
    x=torch.ones(5,2,4,1)
    y=torch.ones(  3,1,1)
    
    在这里插入图片描述
    # x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等。
    x=torch.ones(5,2,4,1)
    y=torch.ones(1,1)
    print((x+y).size())
    
    在这里插入图片描述

3. in - place 语义

\qquad in-place operation称为原地操作符,在pytorch中是指改变一个tensor的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch中经常加后缀“_”来代表原地操作符,例:.add_().scatter()in-place操作不允许tensor像广播那样改变形状。
示例:

# x 和 y 不可以广播
x=torch.empty(1,3,1)
y=torch.empty(3,1,7)
print((x.add_(y)).size())

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/m0_52650517/article/details/119913625