Addition, subtraction, and logical operations between tensors with mismatched shapes in PyTorch (broadcasting)

Rule 1: If the two tensors to be added have a different number of dimensions, first align the shape of the lower-dimensional tensor with the shape of the higher-dimensional tensor, starting from the right, and fill the missing leading axes with 1.

For example, in the following code b has fewer dimensions than a, so before it is added to a, the shape of b is first expanded to [1,1,5,6].

import torch

a = torch.ones([8, 4, 5, 6])
b = torch.ones([5, 6])
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

Once the shapes are aligned, rule 2 applies.
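The alignment in rule 1 can be reproduced by hand. The following is a minimal sketch, assuming the same a and b as above; the names missing_axes, b_aligned and c_aligned are only illustrative. It pads the shape of b with leading 1s and checks that the sum is the same as adding b directly.

```python
import torch

a = torch.ones([8, 4, 5, 6])
b = torch.ones([5, 6])

# Rule 1 by hand: pad b's shape with leading 1s until it has as many axes as a.
# The data is unchanged; only size-1 axes are added on the left.
missing_axes = a.dim() - b.dim()
b_aligned = b.reshape((1,) * missing_axes + tuple(b.shape))
print('b_aligned =', b_aligned.size())
# b_aligned = torch.Size([1, 1, 5, 6])

# Adding the hand-aligned tensor gives the same result as adding b directly.
c = a + b
c_aligned = a + b_aligned
print(torch.equal(c, c_aligned))
# True
```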

Rule 2: When two tensors have the same number of dimensions, the sizes of each pair of corresponding axes must either be equal, or one of them must be 1.

When adding, every axis of size 1 is expanded (its values are repeated) to match the corresponding axis of the other tensor, so that the two tensors end up with exactly the same shape. The elements at corresponding positions can then be added.
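The expansion step can be made visible with torch.broadcast_tensors and Tensor.expand. This is a small sketch with an illustrative b of shape [1, 4, 1, 6] (not taken from the examples in this post). Note that PyTorch does not physically copy the data when expanding: the expanded axes get a stride of 0, so the same values are simply read repeatedly.

```python
import torch

a = torch.ones([8, 4, 5, 6])
b = torch.ones([1, 4, 1, 6])

# broadcast_tensors applies the rules and returns both operands
# expanded to the common shape.
a_full, b_full = torch.broadcast_tensors(a, b)
print('a_full =', a_full.size())
print('b_full =', b_full.size())

# expand() does the same for a single tensor. Size-1 axes are stretched
# as a view; no new memory is allocated for the repeated values.
b_expanded = b.expand(8, 4, 5, 6)
print('strides of b_expanded =', b_expanded.stride())

# Adding the explicitly expanded tensors matches the implicit addition.
same = torch.equal(a + b, a_full + b_full)
print(same)
```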

Examples that can be added:

1. In each pair of corresponding axes, the sizes are either equal or one of them is 1, so the tensors can be added. Otherwise they could not be added.

a = torch.ones([8, 4, 5, 6])
b = torch.ones([1, 1, 5, 6])
c = a+b
# c = torch.Size([8, 4, 5, 6])

A further example:

a = torch.ones([5, 1, 1, 5])
b = torch.ones([5, 5])
c = a+b
print('c =', c.size())
# c = torch.Size([5, 1, 5, 5])

Here, the shape of a is (5,1,1,5), the shape of b is (5,5), and the shape of the result c is (5,1,5,5). Both a and b are expanded during the addition. First, according to rule 1, b is aligned with a from the right, so the shape of b becomes (1,1,5,5). Then, according to rule 2, a and b have the same number of dimensions and each pair of corresponding axes is either equal or contains a 1, so the operation can be performed. In the actual operation, each tensor expands the axes where its own size is 1 and the other tensor's size is not: counting axes from 0, b expands axis 0 from 1 to 5 and a expands axis 2 from 1 to 5, while axis 1 stays 1 because it is 1 in both tensors. Both shapes become (5,1,5,5), so the addition can be performed.
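The two rules can be written down as a short pure-Python function that works on shapes only. This is a sketch, not PyTorch's own implementation; broadcast_shape is a hypothetical helper introduced here for illustration, and its result is compared against the real torch.broadcast_shapes.

```python
import torch

def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: result shape under rules 1 and 2, or ValueError."""
    shape_a, shape_b = list(shape_a), list(shape_b)

    # Rule 1: left-pad the shorter shape with 1s so both align from the right.
    n = max(len(shape_a), len(shape_b))
    shape_a = [1] * (n - len(shape_a)) + shape_a
    shape_b = [1] * (n - len(shape_b)) + shape_b

    # Rule 2: each pair of axes must be equal, or one of them must be 1.
    result = []
    for axis, (x, y) in enumerate(zip(shape_a, shape_b)):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f'axis {axis}: sizes {x} and {y} are incompatible')
    return tuple(result)

print(broadcast_shape([5, 1, 1, 5], [5, 5]))
# (5, 1, 5, 5)
print(torch.broadcast_shapes((5, 1, 1, 5), (5, 5)))
# torch.Size([5, 1, 5, 5])
```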

2. Every axis of one tensor has size 1, so they can be added

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 1, 1, 1])
print('b =',b.size())
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

3. The shapes are exactly equal, so they can be added

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([8, 4, 5, 6])
print('b =',b.size())
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])
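The same two rules apply to subtraction, comparisons, and logical operations, not only to addition. A minimal sketch, assuming a tensor of ones and a lower-dimensional tensor of zeros (values chosen only for illustration):

```python
import torch

a = torch.ones([8, 4, 5, 6])
b = torch.zeros([5, 6])

# b is aligned to [1, 1, 5, 6] and expanded in each case, exactly as for a + b.
diff = a - b                      # subtraction
greater = a > b                   # element-wise comparison, bool tensor
both = torch.logical_and(a, b)    # logical AND, nonzero values count as True

print('diff =', diff.size())
print('greater =', greater.size())
print('both =', both.size())
```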

Examples that cannot be added:

1. On the second axis the sizes are 4 and 2, which are not equal and neither is 1, so the tensors cannot be added

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 2, 1, 6])
print('b =',b.size())
c = a+b  # raises RuntimeError: sizes 4 and 2 do not match
print('c =',c.size())  # never reached

2. On the last axis the sizes are 6 and 3, which are not equal and neither is 1, so the tensors cannot be added

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 4, 1, 3])
print('b =',b.size())
c = a+b  # raises RuntimeError: sizes 6 and 3 do not match
print('c =',c.size())  # never reached
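When the shapes are incompatible, PyTorch raises a RuntimeError at the point of the operation. The sketch below, which reuses the shapes from the last example, shows one way to catch it; the failed flag is only there for illustration.

```python
import torch

a = torch.ones([8, 4, 5, 6])
b = torch.ones([1, 4, 1, 3])

failed = False
try:
    c = a + b
except RuntimeError as err:
    # The message names the mismatched sizes and the axis they occur on.
    failed = True
    print('cannot add:', err)

print('failed =', failed)
# failed = True
```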

Origin blog.csdn.net/ytusdc/article/details/128106265