【Pytorch】torch.argmax 函数详解


一、一个参数时的 torch.argmax 函数

官网链接:TORCH.ARGMAX

1. 介绍

torch.argmax(input) → LongTensor

返回输入张量 input 所有元素中的最大值的下标(如果有多个最大值,则返回第一个最大值的索引)。

2. 实例

import torch
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a))

输出结果:

tensor([[-0.7018,  1.1887, -0.2344,  0.3216],
        [ 1.3548, -0.8575, -1.0585, -0.3462],
        [ 0.5845,  0.2345,  1.6444,  1.1129],
        [-1.1226, -0.5765, -0.4906,  0.0132]])
tensor(10)

在所有的元素中,第11个元素 1.6444 最大,其索引是 10 ,因此返回 tensor(10)。


二、多个参数时的 torch.argmax 函数

1. 介绍

torch.argmax(input, dim=None, keepdim=False)
  • 返回一个张量 input 在某一维度 dim 上的最大值的索引(返回 input 的指定维度 dim 上的最大值的序号)。
  • input (Tensor) - 输入张量。
  • dim (int) - 要减少的维度(指定维度)。如果为None,则返回扁平输入的argmax。dim 的不同值表示不同维度。在二维中,dim=0 表示行,此时要压缩行,找列的最大值;dim=1 表示列,此时要压缩列,找行的最大值。广泛的来说,我们不管一个矩阵是几维的,比如一个矩阵维度如下:(d0,d1,…,dn−1) ,那么 dim=0 就表示对应到d0 也就是第一个维度,dim=1表示对应到也就是第二个维度,依次类推。指定哪个维度,哪个维度就要消失,就要被压缩。
  • Keepdim (bool) - 输出张量是否保留dim。如果dim=None 则忽略。
  • 返回值:指定维度 dim 消失之后的矩阵。dim (int) – the dimension to reduce。因为在该维度找了最大值,相当于该维度就被压缩了,只保留了其他维度。这样不好理解,接下来看看例子。

2. 实例

实例1:二维矩阵

import torch
a = torch.tensor(
              [
                  [1, 5, 5, 2],
                  [9, -6, 2, 8],
                  [-3, 7, -9, 1]
              ])
print(a.shape)
b = torch.argmax(a, dim=0)  # 压缩行,返回列最大值的序号
print(b)
print(b.shape)

输出结果:

torch.Size([3, 4])
tensor([1, 2, 0, 1])
torch.Size([4])

指定的维度是 0 ,也就是行,要压缩行,就要找列的最大值。
从 [3, 4] -> [4],可见第一个维度 3 消失了。

import torch

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1)) #压缩列,返回行最大值的序号

输出结果:

tensor([[-1.3736,  0.8958, -0.6470,  1.3395],
        [-0.4279,  0.0682,  0.7635,  1.1857],
        [ 1.7861, -0.6515, -0.5456, -0.3066],
        [ 1.1898, -0.0208, -0.3662,  0.1799]])
tensor([3, 3, 0, 0])

指定的维度是 1 ,也就是列,要压缩列,就要找行的最大值。

实例2:三维矩阵

import torch

a = torch.tensor([
    [
        [1, 5, 5, 2],
        [9, -6, 2, 8],
        [-3, 7, -9, 1]
    ],

    [
        [-1, 7, -5, 2],
        [9, 6, 2, 8],
        [3, 7, 9, 1]
    ]])

print(a.shape)
b = torch.argmax(a, dim=0)
print(b)
print(b.shape)

输出结果:

扫描二维码关注公众号,回复: 14577398 查看本文章
torch.Size([2, 3, 4])
tensor([[0, 1, 0, 0],
        [0, 1, 0, 0],
        [1, 0, 1, 0]])
torch.Size([3, 4])

从 [2, 3, 4] -> [3, 4],可见第一个维度 2 消失了。

实例3:保留dim

import torch

a = torch.tensor([
    [
        [1, 5, 5, 2],
        [9, -6, 2, 8],
        [-3, 7, -9, 1]
    ],

    [
        [-1, 7, -5, 2],
        [9, 6, 2, 8],
        [3, 7, 9, 1]
    ]])

print(a.shape)
b = torch.argmax(a, dim=0, keepdim=True)
print(b)
print(b.shape)

输出结果:

torch.Size([2, 3, 4])
tensor([[[0, 1, 0, 0],
         [0, 1, 0, 0],
         [1, 0, 1, 0]]])
torch.Size([1, 3, 4])

与实例2的不同之处:加了 keepdim=True 参数,输出从 [3, 4] -> [1, 3, 4],保留了被压缩的第一维,只不过从 2 变成了压缩后的 1 。


参考链接

  1. 【搞懂PyTorch】torch.argmax() 函数详解

猜你喜欢

转载自blog.csdn.net/weixin_44211968/article/details/128216020
今日推荐