# Six floats 0..5 laid out as two stacked 1x3 matrices; multiplying the
# integer range by a Python float promotes it to a floating-point tensor.
a = torch.reshape(1. * torch.arange(6), (2, 1, 3))
print(a)
print(a.size())
# Averaging over dim 0 collapses the two matrices element-wise into one 1x3.
b = torch.mean(a, dim=0)
print(b)
print(b.size())
--------------------------------------------
res:
tensor([[[0., 1., 2.]],
[[3., 4., 5.]]])
torch.Size([2, 1, 3])
tensor([[1.5000, 2.5000, 3.5000]])
torch.Size([1, 3])
3 次元テンソル (m, n, q) の場合:
ls = [
[[1, 2], [3, 4]],
[[5, 6], [7, 8]]
]
ls.shape = (2, 2, 2)
dim = 0 の場合:
行と列の位置を固定し、第 0 次元(外側の 2 つのブロック)に沿って対応する要素同士を足して平均を取ります。
(1 + 5)/ 2 = 3、
(2 + 6)/ 2 = 4、
(3 + 7)/ 2 = 5、
(4 + 8)/ 2 = 6、
ls.mean(dim = 0) =
[
[3, 4], [5, 6]
]
ls.mean(dim = 0).shape = (2, 2)
---------------------------------------------
dim = 1 の場合:
ブロックと列の位置を固定し、第 1 次元(各ブロック内の行)に沿って要素同士を足して平均を取ります。
(1+ 3)/ 2 = 2、
(2 + 4)/ 2 = 3、
(5 + 7)/ 2 = 6、
(6 + 8)/ 2 = 7、
ls.mean(dim = 1) =
[
[2, 3],
[6, 7]
]
ls.mean(dim = 1).shape = (2, 2)
# Eight floats 0..7 laid out as two 2x2 matrices; multiplying the integer
# range by a Python float promotes it to a floating-point tensor.
a = torch.reshape(1. * torch.arange(8), (2, 2, 2))
print(a)
print(a.size())
# Averaging over dim 1 collapses the rows inside each 2x2 matrix, leaving
# one row of column means per matrix.
b = torch.mean(a, dim=1)
print(b)
print(b.size())
-------------------------------
tensor([[[0., 1.],
[2., 3.]],
[[4., 5.],
[6., 7.]]])
torch.Size([2, 2, 2])
tensor([[1., 2.],
[5., 6.]])
torch.Size([2, 2])