Uso simple de Pytorch torch.mean()

En pocas palabras, es un promedio .
Por ejemplo, los siguientes tres casos simples:

import torch

x1 = torch.Tensor([1, 2, 3, 4])
x2 = torch.Tensor([[1],
                   [2],
                   [3],
                   [4]])
x3 = torch.Tensor([[1, 2],
                   [3, 4]])
y1 = torch.mean(x1)
y2 = torch.mean(x2)
y3 = torch.mean(x3)
print(y1)
print(y2)
print(y3)

producción:

tensor(2.5000)
tensor(2.5000)
tensor(2.5000)

Es decir, cuando no se especifica ninguna dimensión, se promedian todos los números.

Más a menudo, se utiliza el caso dimensional, como:

import torch

x = torch.Tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
y_0 = torch.mean(x, dim=0)
y_1 = torch.mean(x, dim=1)
print(x)
print(y_0)
print(y_1)

producción:

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([2.5000, 3.5000, 4.5000])
tensor([2., 5.])

La forma del tensor de entrada es (2, 3), donde 2 es la dimensión 0 y 3 es la dimensión 1. Promediar la 0ª dimensión produce un tensor de forma (1, 3); promediar la 1ª dimensión produce un tensor de forma (2, 1).
Se puede entender que el promedio de qué dimensión es promediar todos los números de la dimensión y aplanarlos en 1 capa (de hecho, esta capa se fusiona, como en el ejemplo anterior, el tensor bidimensional está promediando Después de contar, se vuelve unidimensional), y la forma de otras dimensiones no afecta.
Si desea mantener la dimensionalidad constante (como en las redes profundas), puede agregar parámetros keepdim=True:

y = torch.mean(x, dim=1, keepdim=True)

Supongo que te gusta

Origin blog.csdn.net/qq_40714949/article/details/115485140
Recomendado
Clasificación