Análisis de parámetros de dimensión de Torch.sum

a = torch.ones(2, 3)
print(a)
print(a.shape)
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=-1))

#res
tensor([[1., 1., 1.],
        [1., 1., 1.]])

torch.Size([2, 3])
tensor([2., 2., 2.])
tensor([3., 3.])
tensor([3., 3.])

Comprensión personal: para una matriz bidimensional, dim = 0, lo que significa que las filas son fijas y las columnas se agregan; dim = 1, las columnas son fijas y las filas se agregan.

Vale la pena señalar que los resultados de dim = -1 y dim = 2 son los mismos. Especulación personal: el valor de dim es (0,1) como una matriz ls = [0,1], hay ls [-1] = ls [1] = 1 en la sintaxis de Python. El mismo pensamiento.

 

a = torch.ones((2,2,3))
print(a)
print(a.shape)
print(a.sum(dim=-1))
print(a.sum(dim=0))
print(a.sum(dim=1))
print(a.sum(dim=2))

resultado:

a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=0)
print(b)
print(b.shape)
----------------------------------
res:
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 4.,  6.],
        [ 8., 10.]])
torch.Size([2, 2])
a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.sum(dim=1)
print(b)
print(b.shape)

--------------------------------------
tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[ 2.,  4.],
        [10., 12.]])
torch.Size([2, 2])

 

Matriz tridimensional: la idea es la misma que la anterior, dim = (0,1,2), dim = 1, fila fija, adición de columna; dim = 2, columna fija, adición de fila. dim = -1 es lo mismo que dim = 2. En cuanto a dim = 0, no está claro cómo agregar.

Supongo que te gusta

Origin blog.csdn.net/weixin_40823740/article/details/114988513
Recomendado
Clasificación