a = torch.ones(2, 3)
print(a)
print(a.shape)
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=-1))
#res
tensor([[1., 1., 1.],
[1., 1., 1.]])
torch.Size([2, 3])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor([3., 3.])
Comprensión personal: para una matriz bidimensional, dim = 0, lo que significa que las filas son fijas y las columnas se agregan; dim = 1, las columnas son fijas y las filas se agregan.
Vale la pena señalar que los resultados de dim = -1 y dim = 2 son los mismos. Especulación personal: el valor de dim es (0,1) como una matriz ls = [0,1], hay ls [-1] = ls [1] = 1 en la sintaxis de Python. El mismo pensamiento.
a = torch.ones((2,2,3))
print(a)
print(a.shape)
print(a.sum(dim=-1))
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=2))
resultado:
a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=0)
print(b)
print(b.shape)
----------------------------------
res:
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 4., 6.],
[ 8., 10.]])
torch.Size([2, 2])
a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=1)
print(b)
print(b.shape)
--------------------------------------
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 2., 4.],
[10., 12.]])
torch.Size([2, 2])