Pytorch张量统计运算dim的理解

创建两个维度2x3的张量
>>> input = torch.randn(2,2,3)
>>> input
tensor([[[-0.3491, -0.6243,  0.1242],
         [-1.5700,  1.0065, -1.2502]],

        [[-0.5597,  0.3357,  0.5915],
         [-0.6579, -0.9022,  0.0043]]])

>>> input.size()
torch.Size([2, 2, 3])

**dim=0 表示对每一维度相同位置的数值进行max运算**

>>> input.max(dim=0)
torch.return_types.max(
values=tensor([[-0.3491,  0.3357,  0.5915],
               [-0.6579,  1.0065,  0.0043]]),
indices=tensor([[0, 1, 1],
               [1, 0, 1]]))

**dim=1  表示求每一维的列里进行max运算**

>>> input.max(dim=1)
torch.return_types.max(
values=tensor([[-0.3491,  1.0065,  0.1242],
              [-0.5597,  0.3357,  0.5915]]),
indices=tensor([[0, 1, 0],
               [0, 0, 0]]))

**dim=2 表示求每一维的行里进行max运算**

>>> input.max(dim=2)
torch.return_types.max(
values=tensor([[0.1242, 1.0065],
              [0.5915, 0.0043]]),
indices=tensor([[2, 1],
               [2, 2]]))
发布了66 篇原创文章 · 获赞 1 · 访问量 7014

猜你喜欢

转载自blog.csdn.net/qq_41128383/article/details/105551231