Analyse des paramètres de dimension Torch.sum

a = torch.ones(2, 3)
print(a)
print(a.shape)
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=-1))

#res
tensor([[1., 1., 1.],
        [1., 1., 1.]])

torch.Size([2, 3])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor([3., 3.])

Compréhension personnelle: pour un tableau à deux dimensions, dim = 0, ce qui signifie que les lignes sont fixes et que les colonnes sont ajoutées; dim = 1, les colonnes sont fixes et les lignes sont ajoutées.

Il est à noter que les résultats de dim = -1 et dim = 2 sont les mêmes. Spéculation personnelle: la valeur de dim est (0,1) tout comme un tableau ls = [0,1], il y a ls [-1] = ls [1] = 1 dans la syntaxe python. La même pensée.

 

a = torch.ones((2,2,3))
print(a)
print(a.shape)
print(a.sum(dim=-1))
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=2))

résultat:

a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=0)
print(b)
print(b.shape)
----------------------------------
res:
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 4.,  6.],
        [ 8., 10.]])
torch.Size([2, 2])
a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=1)
print(b)
print(b.shape)

--------------------------------------
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 2.,  4.],
        [10., 12.]])
torch.Size([2, 2])

 

Tableau tridimensionnel: L'idée est la même que ci-dessus, dim = (0,1,2), dim = 1, ligne fixe, ajout de colonne; dim = 2, colonne fixe, ajout de ligne. dim = -1 est identique à dim = 2. Quant à dim = 0, il n'est pas clair comment ajouter.

Je suppose que tu aimes

Origine blog.csdn.net/weixin_40823740/article/details/114988513
conseillé
Classement