【PyTorch】torch.mean(), dim=0, dim=1 详解

创建一个tensor,这个tensor是一个元素类型为浮点型的2维数组

import torch
s = torch.arange(6,dtype=float).reshape((2,3))
print(s)
print(s.shape)# 查看tensor的形状
tensor([[0., 1., 2.],
        [3., 4., 5.]], dtype=torch.float64)
torch.Size([2, 3])

dim属性的全称是dimension,表示维度。dim=0为第0个维度,代表行。

对于torch.mean(s,dim=0),表示跨行求平均。

得到的结果是一个向量,分别对应于 1.5=(0.0+3.0)/2, 2.5=(1.0+4.0)/2, 3.5=(2.0+5.0)/2

s1 = torch.mean(s, dim=0)
print(s1)
tensor([1.5000, 2.5000, 3.5000], dtype=torch.float64)

同理,对于dim=1为第一个维度,代表列。

对于torch.sum(s,dim=1),表示跨列求平均。

得到的结果同样是一个向量,分别对应于 1.0=(0.0+1.0+2.0)/3, 4.0=(3.0+4.0+5.0)/3

s2 = torch.mean(s, dim=1)
print(s2)
tensor([1., 4.], dtype=torch.float64)

在上文中我们看到无论是s1或者s2,都是1维数组,而原始数据是一个2维数组。这是因为经过mean操作后,数据自动减少一个维度。

这里我们可以再看一个三维数组的例子,会更明显的看到这种变化。

m = torch.arange(24, dtype=float).reshape((2,3,4))
print(m.shape)
torch.Size([2, 3, 4])

当我们执行dim=0时,数据的形状会变[3,4];

当我们执行dim=1时,数据的形状会变[2,4];

当我们执行dim=2时,数据的形状会变[2,3];

当我们执行dim=[0,2]时,数据的形状会变[3];

m1 = torch.mean(m, dim=0)
m2 = torch.mean(m, dim=1)
m3 = torch.mean(m, dim=2)
m4 = torch.mean(m, dim=[0,2])
print("m1.shape:",m1.shape)
print("m2.shape:",m2.shape)
print("m3.shape:",m3.shape)
print("m4.shape:",m4.shape)
m1.shape: torch.Size([3, 4])
m2.shape: torch.Size([2, 4])
m3.shape: torch.Size([2, 3])
m4.shape: torch.Size([3])

有时我们不希望经过mean操作后的维度发生改变。这时我们可以使用keepdims=True的设置。

可以看到使用keepdims=True后数据仍然是三维数组

a = torch.arange(24, dtype=float).reshape((2,3,4))
print(a.shape)
torch.Size([2, 3, 4])
mean = torch.mean(a, dim=0, keepdims=True)
print(mean.shape)
torch.Size([1, 3, 4])

注意:

  1. 对tensor使用mean操作时,需要转换数据类型,如果使用int型会报错。RuntimeError: Can only calculate the mean of floating types. Got Long instead

  2. 本文介绍的内容,针对tensor.sum()等函数同样适用。

猜你喜欢

转载自blog.csdn.net/u010414589/article/details/115205681