Simple usage of PyTorch torch.mean()

Simply put, torch.mean() computes an average.
For example, the following three simple cases:

import torch

x1 = torch.Tensor([1, 2, 3, 4])
x2 = torch.Tensor([[1],
                   [2],
                   [3],
                   [4]])
x3 = torch.Tensor([[1, 2],
                   [3, 4]])
y1 = torch.mean(x1)
y2 = torch.mean(x2)
y3 = torch.mean(x3)
print(y1)
print(y2)
print(y3)

output:

tensor(2.5000)
tensor(2.5000)
tensor(2.5000)

That is, when no dimension is specified, all numbers are averaged.

More often, a specific dimension is given, for example:

import torch

x = torch.Tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
y_0 = torch.mean(x, dim=0)
y_1 = torch.mean(x, dim=1)
print(x)
print(y_0)
print(y_1)

output:

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([2.5000, 3.5000, 4.5000])
tensor([2., 5.])

The shape of the input tensor is (2, 3): dimension 0 has size 2 and dimension 1 has size 3. Averaging over dimension 0 collapses it, giving a tensor of shape (3,); averaging over dimension 1 gives a tensor of shape (2,).
In other words, taking the mean over a dimension averages all numbers along that dimension and removes it, while the shapes of the other dimensions are unchanged (in the example above, the 2-dimensional tensor becomes 1-dimensional after the reduction).
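
A quick way to confirm this is to print the shapes directly, continuing with the same x as above (a minimal sketch):

import torch

x = torch.Tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
print(torch.mean(x, dim=0).shape)  # torch.Size([3]) -- dimension 0 is removed
print(torch.mean(x, dim=1).shape)  # torch.Size([2]) -- dimension 1 is removed
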
If you want to keep the number of dimensions unchanged (as is often needed in deep networks), add the parameter keepdim=True:

y = torch.mean(x, dim=1, keepdim=True)
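
For example, continuing with the same x as above, a minimal sketch of what keepdim=True changes:

import torch

x = torch.Tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
y = torch.mean(x, dim=1, keepdim=True)
print(y)        # tensor([[2.],
                #         [5.]])
print(y.shape)  # torch.Size([2, 1]) instead of torch.Size([2])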
