PyTorch | 广播机制(broadcast)
1. 广播机制定义
\qquad 如果一个PyTorch
操作支持广播,则其Tensor
参数可以自动扩展为相等大小(不需要复制数据)。通常情况下,小一点的数组会被 broadcast
到大一点的,这样才能保持大小一致。
2. 广播机制规则
- 如果遵守以下规则,则两个
tensor
是“可广播的”:- 每个
tensor
至少有一个维度; - 遍历
tensor
所有维度时,从末尾随开始遍历,两个tensor
存在下列情况:tensor
维度相等。tensor
维度不等且其中一个维度为1。tensor
维度不等且其中一个维度不存在。
- 每个
- 如果两个
tensor
是“可广播的”,则计算过程遵循下列规则:- 如果两个
tensor
的维度不同,则在维度较小的tensor
的前面增加维度,使它们维度相等。 - 对于每个维度,计算结果的维度值取两个
tensor
中较大的那个值。 - 两个
tensor
扩展维度的过程是将数值进行复制。
- 如果两个
- 示例:
# 相同维度,一定可以 broadcasting x=torch.ones(5,7,3) y=torch.ones(5,7,3)
# x和y不能被广播,因为x没有符合“至少有一个维度”,所以不可以broadcasting x=torch.ones((0)) y=torch.ones(5,7)
# x 和 y 可以广播 x=torch.ones(5,3,4,1) y=torch.ones( 3,1,1) # 从尾部维度开始遍历 # 1st尾部维度: x和y相同,都为1。 # 2nd尾部维度: y为1,x为4,符合维度不等且其中一个维度为1,则广播为4。 # 3rd尾部维度: x和y相同,都为3。 # 4th尾部维度: y维度不存在,x为5,符合维度不等且其中一个维度不存在,则广播为5。
# x 和 y 不可以广播,因为3rd尾部维度x为2,y为3,不符合维度不等且其中一个维度为1。 x=torch.ones(5,2,4,1) y=torch.ones( 3,1,1)
# x 和 y 可以广播,在维度较小y前面增加维度,使它们维度相等。 x=torch.ones(5,2,4,1) y=torch.ones(1,1) print((x+y).size())
3. in - place 语义
\qquad in-place operation
称为原地操作符,在pytorch
中是指改变一个tensor
的值的时候,不经过复制操作,而是直接在原来的内存上改变它的值。在pytorch
中经常加后缀“_
”来代表原地操作符,例:.add_()
、.scatter()
。in-place
操作不允许tensor
像广播那样改变形状。
示例:
# x 和 y 不可以广播
x=torch.empty(1,3,1)
y=torch.empty(3,1,7)
print((x.add_(y)).size())