torch.argmax(input, dim, keepdim=False)

导读

最近有时间看一些目标检测项目的代码(基于Pytorch),里边很多Pytorch的相关操作都忘记了,特来此记录一下,用以加深记忆,而且还能以备一样处境的同学前来查询。今天的主角是torch.argmax(input, dim, keepdim=False)。

官方文档地址

https://pytorch.org/docs/stable/generated/torch.argmax.html

torch.argmax(input) → LongTensor

Returns the indices of the maximum value of all elements in the input tensor.

根据官方的解释,该函数可以返回输入张量中所有元素的最大值的索引。当然这只是最初级的用法,根据输入参数的不同,其返回的结果也不同。下面我们一起了解它的参数都有哪些作用。

参数解析

Parameters

  • input (Tensor) : the input tensor.
  • dim (int) :the dimension to reduce. If None, the argmax of the flattened input is returned.
  • keepdim (bool) : whether the output tensor has dim retained or not. Ignored if dim=None.

这是官网上对参数的解释,input就是我们输入的要操作的张量;dim是我们选择的要在张量的哪个维度上进行计算,输出这个维度最大值的索引,这里一行元素变成一个索引,所以官网中用了reduce;keepdim是询问输出是否与输入保持一样的形状,默认是不保持(False)。

举例演示

首先输入一个张量,注意我们输入的这个张量的shape为[5, 9]

>>> x = torch.randn(5,9)
>>> print(x)
tensor([[ 0.3918,  0.3978,  0.2819, -0.8487, -1.0499,  0.2124, -1.3527, -1.5335,
          1.1050],
        [ 0.8450, -0.3717, -0.4705, -0.4024,  2.1019, -0.8545,  1.9085,  0.5792,
         -0.4279],
        [ 0.1993, -0.2887,  0.4467,  0.4878,  1.4934, -1.3862,  0.3576, -0.2363,
         -2.0700],
        [ 0.0536,  0.9385,  1.2661, -0.3469, -0.5772, -0.7822,  0.8315, -1.7256,
         -0.4979],
        [ 1.1592, -0.1604,  0.2798,  0.5974,  0.1782, -2.3354, -1.7775, -0.8366,
          1.8993]])

接下来,现在第一个维度上进行操作

>>> torch.argmax(x,dim=0)
tensor([4, 3, 3, 4, 1, 0, 1, 1, 4])

第一个维度是行,即按行计算,我们看到结果输出的维度为9,正好是输入张量x的列数。torch.argmax()的计算方式如下:
每次在所有行的相同位置取元素,然后计算取得元素集合的最大值索引。
第一次取所有行的第一位元素,x[:, 0], 得到

tensor([0.3918, 0.8450, 0.1993, 0.0536, 1.1592])

第二次取所有行的第二位元素,x[:, 1], 得到

tensor([0.3978, -0.3717, -0.2887, 0.9385, -0.1604])

依次类推,x有9列,我们也可以取9次,所有取的结果如下:

tensor([ 0.3918,  0.8450,  0.1993,  0.0536,  1.1592])
tensor([ 0.3978, -0.3717, -0.2887,  0.9385, -0.1604])
tensor([ 0.2819, -0.4705,  0.4467,  1.2661,  0.2798])
tensor([-0.8487, -0.4024,  0.4878, -0.3469,  0.5974])
tensor([-1.0499,  2.1019,  1.4934, -0.5772,  0.1782])
tensor([-1.3527,  1.9085,  0.3576,  0.8315, -1.7775])
tensor([-1.5335,  0.5792, -0.2363, -1.7256, -0.8366])
tensor([ 1.1050, -0.4279, -2.0700, -0.4979,  1.8993])

然后分别计算以上每个张量中元素的最大值的索引,便得到tensor([4, 3, 3, 4, 1, 0, 1, 1, 4])

同理,按照列来操作也是一样的思路,这里就不详细说了,看结果:

>>> torch.argmax(x,dim=1)
tensor([8, 4, 4, 2, 8])

经过上边例子的演示,我们可以知道torch.argmax(input,dim)可以返回input中dim维度上的最大值索引。
我们给x在目标检测中赋予具体的含义,假如x的形状为[num_bbox, anchor],那么x便是5个预测框分别与9个anchor计算得到的交并比,我们要选出来与预测框交并比最大的那个anchor,用来回归预测框越来越接近GT。这时候就要用到torch.argmax()找到与bbox交并比最大的anchor的序号。

>>> torch.argmax(x,dim=1)
tensor([8, 4, 4, 2, 8])

即与第一个预测框交并比最大的是第9个anchor,与第二个预测框交并比最大的是第5个anchor…

猜你喜欢

转载自blog.csdn.net/Just_do_myself/article/details/123358048