Torch.sumディメンションパラメータ分析

a = torch.ones(2, 3)
print(a)
print(a.shape)
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=-1))

#res
tensor([[1., 1., 1.],
        [1., 1., 1.]])

torch.Size([2, 3])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor([3., 3.])

個人的な理解:2次元配列の場合、dim = 0は、行が固定され、列が追加されることを意味します。dim= 1、列は固定され、行が追加されます。

dim = -1とdim = 2の結果が同じであることは注目に値します。個人的な推測:dimの値は配列と同じように(0,1)です。ls= [0,1]、Python構文にはls [-1] = ls [1] = 1があります。同じ考え。

 

a = torch.ones((2,2,3))
print(a)
print(a.shape)
print(a.sum(dim=-1))
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=2))

結果:

a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=0)
print(b)
print(b.shape)
----------------------------------
res:
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 4.,  6.],
        [ 8., 10.]])
torch.Size([2, 2])
a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=1)
print(b)
print(b.shape)

--------------------------------------
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 2.,  4.],
        [10., 12.]])
torch.Size([2, 2])

 

3次元配列:考え方は上記と同じで、dim =(0,1,2)、dim = 1、固定行、列の追加、dim = 2、固定列、行の追加です。dim = -1はdim = 2と同じです。dim = 0に関しては、追加方法が明確ではありません。

おすすめ

転載: blog.csdn.net/weixin_40823740/article/details/114988513