pytorch で次元が一致しないテンソル間の加算と減算/論理演算の問題

ルール 1: 追加する 2 つのテンソルの次元が一致しない場合は、まず、低い次元のテンソルと高い次元のテンソルを右から揃えます。

たとえば、次のコードでは、b の次元が低いため、a に追加されると、まず b の次元が [1,1,5,6] に拡張されます。

a = torch.ones([8, 4, 5, 6])
b = torch.ones([5, 6])
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

調整後、ルール 2 にジャンプできます。

ルール 2: 2 つのテンソルが同じ次元を持つ場合、対応する軸の値は同じか 1 である必要があります。

追加するときは、まったく同じ次元の 2 つのテンソルを取得するために、1 であるすべての軸をコピーして展開します。次に、対応する位置を追加します。

追加できる例:

1. 対応する各軸が等しいか、どちらかが 1 なので加算できます。それ以外の場合は追加できません。

a = torch.ones([8, 4, 5, 6])
b = torch.ones([1, 1, 5, 6])
c = a+b
# c = torch.Size([8, 4, 5, 6])

さらに遠く:

a = torch.ones([5, 1, 1, 5])
b = torch.ones([5, 5])
c = a+b
print('c =', c.size())
# c = torch.Size([5, 1, 5, 5])

ここで、a の次元は (5,1,1,5)、b の次元は (5,5)、最終結果の c の次元は (5,1,5,5) です。 aとbの両方に対して演算を行います。 次元拡張
演算を行う場合は、ルール1に従ってbをaに右から並べてbの次元が(1,1,5,5)となり、
その後ルールに従います。 2では a と b の次元が同じなので演算が可能ですが、実際の演算では a も b も次元 1 の軸をコピーして展開し、次元は (5,1, 5,5)、加算演算を実行できます。

2. 1 つの次元はすべて 1 であり、追加することができます。

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 1, 1, 1])
print('b =',b.size())
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

3. 寸法はまったく同じであり、追加することができます。

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([8, 4, 5, 6])
print('b =',b.size())
c = a+b
print('c =',c.size())
# c = torch.Size([8, 4, 5, 6])

追加できない例:

1. 4 は 2 に等しくないため、加算できません

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 2, 1, 6])
print('b =',b.size())
c = a+b
print('c =',c.size())

2. 3 は 6 に等しくないため、加算できません

a = torch.ones([8, 4, 5, 6])
print('a =',a.size())
b = torch.ones([1, 4, 1, 3])
print('b =',b.size())
c = a+b
print('c =',c.size())

おすすめ

転載: blog.csdn.net/ytusdc/article/details/128106265