Torch.mean()寸法の説明

a = torch.arange(6) * 1.
# print(a)
a = a.reshape(2, 1, 3)
print(a)
print(a.shape)
b = a.mean(dim=0)
print(b)
print(b.shape)

--------------------------------------------
res:
tensor([[[0., 1., 2.]],

        [[3., 4., 5.]]])
torch.Size([2, 1, 3])
tensor([[1.5000, 2.5000, 3.5000]])
torch.Size([1, 3])

三次元(m、n、q):

ls = [

[[1,2]、[3,4]]、

[[5,6]、[7,8]]

]

ls.shape = 2 * 2 * 2

dim = 0、

固定の行と列を追加します。

(1 + 5)/ 2 = 3、

(2 + 6)/ 2 = 4、

(3 + 7)/ 2 = 5、

(4 + 8)/ 2 = 6、

ls.mean =

[

[3,4]、[5,6]

]  

ls.shape =(2 * 2)

---------------------------------------------

dim = 1

固定列、行の追加

(1+ 3)/ 2 = 2、

(2 + 4)/ 2 = 3、

(5 + 7)/ 2 = 6、

(6 + 8)/ 2 = 7、

ls.mean(dim = 1)=

[

[2,3]、

[6,7]

]

ls.shape =(2 * 2)  

a = torch.arange(8) * 1.
# print(a)
a = a.reshape(2, 2, 2)
print(a)
print(a.shape)
b = a.mean(dim=1)
print(b)
print(b.shape)
-------------------------------

tensor([[[0., 1.],
         [2., 3.]],

        [[4., 5.],
         [6., 7.]]])
torch.Size([2, 2, 2])
tensor([[1., 2.],
        [5., 6.]])
torch.Size([2, 2])

 

おすすめ

転載: blog.csdn.net/weixin_40823740/article/details/115252326