PyTorch tensor indexing

I. Returning the index of the maximum or minimum value in PyTorch

1 Official documentation

1.1 Introduction to torch.argmax()

Returns the index of the maximum value.

Signature:
     torch.argmax(input, dim, keepdim=False) → LongTensor

Returns:
     Returns the indices of the maximum values of a tensor across a dimension.

Parameters:
	input (Tensor) – the input tensor.
	dim (int) – the dimension to reduce. If None, the argmax of the flattened input is returned.
	keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.
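A minimal sketch of how `dim` and `keepdim` change the output shape, using a small hand-written tensor (the values are arbitrary and chosen only for illustration):

```python
import torch

# A fixed 2x3 tensor so the result is reproducible
x = torch.tensor([[1.0, 5.0, 2.0],
                  [7.0, 0.0, 3.0]])

idx = torch.argmax(x, dim=1)                     # dim 1 is removed: shape (2,), values [1, 0]
idx_keep = torch.argmax(x, dim=1, keepdim=True)  # dim 1 kept with size 1: shape (2, 1), values [[1], [0]]
idx_flat = torch.argmax(x)                       # dim=None: index into the flattened tensor, here 3

print(idx.shape, idx_keep.shape, idx_flat)
```

The `keepdim=True` form is handy when the indices are later broadcast against, or gathered from, the original tensor.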

1.2 Introduction to torch.argmin()

Returns the index of the minimum value.

Signature:
     torch.argmin(input, dim, keepdim=False) → LongTensor

Returns:
     Returns the indices of the minimum values of a tensor across a dimension.

Parameters:
	input (Tensor) – the input tensor.
	dim (int) – the dimension to reduce. If None, the argmin of the flattened input is returned.
	keepdim (bool) – whether the output tensor has dim retained or not. Ignored if dim=None.
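A related point, sketched below with arbitrary example values: `torch.max` and `torch.min` called with a `dim` return both the values and their indices, and those indices agree with `argmax`/`argmin` along the same dimension. When only the indices are needed, `argmax`/`argmin` is the shorter call.

```python
import torch

x = torch.tensor([[3.0, -1.0, 4.0],
                  [-2.0, 6.0, 0.5]])

# torch.min(input, dim) returns a (values, indices) named tuple;
# the indices match torch.argmin along the same dim
min_vals, min_idx = torch.min(x, dim=1)
assert torch.equal(min_idx, torch.argmin(x, dim=1))

# Same relationship for torch.max and torch.argmax
max_vals, max_idx = torch.max(x, dim=1)
assert torch.equal(max_idx, torch.argmax(x, dim=1))

print(min_vals, min_idx)  # values [-1.0, -2.0], indices [1, 0]
print(max_vals, max_idx)  # values [4.0, 6.0], indices [2, 1]
```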

2 Code examples

2.1 torch.argmax() example

>>> import torch
>>> Matrix = torch.randn(2,2,2)
>>> print(Matrix)
tensor([[[ 0.3772, -0.1143],
         [ 0.2217, -0.1897]],

        [[ 0.1488, -0.8758],
         [ 1.7734, -0.5929]]])
>>> print(Matrix.argmax(dim=0))
tensor([[0, 0],
        [1, 0]])
>>> print(Matrix.argmax(dim=1))
tensor([[0, 0],
        [1, 1]])
>>> print(Matrix.argmax(dim=2))
tensor([[0, 0],
        [0, 0]])
>>> print(Matrix.argmax())
tensor(6)
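`Matrix.argmax()` with no `dim` returns 6, an index into the flattened (row-major) tensor. A sketch of turning it back into coordinates; `unravel` is a hypothetical helper written only for this example (recent PyTorch releases also provide `torch.unravel_index`, if your version has it):

```python
import torch

# The values printed in the example above, typed in so the result is reproducible
Matrix = torch.tensor([[[0.3772, -0.1143],
                        [0.2217, -0.1897]],
                       [[0.1488, -0.8758],
                        [1.7734, -0.5929]]])

def unravel(flat_idx, shape):
    """Hypothetical helper: convert a row-major flat index into per-dimension coordinates."""
    coords = []
    for size in reversed(shape):
        coords.append(flat_idx % size)  # position within the innermost remaining dim
        flat_idx //= size               # move on to the next outer dim
    return tuple(reversed(coords))

flat = Matrix.argmax().item()          # 6
coords = unravel(flat, Matrix.shape)   # (1, 1, 0)
print(coords, Matrix[coords])          # the element 1.7734
```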

2.2 torch.argmin() example

>>> import torch
>>> Matrix = torch.randn(2,2,2)
>>> print(Matrix)
tensor([[[ 0.5821,  0.2889],
         [ 0.4669, -0.3135]],

        [[-0.4567,  0.2975],
         [-1.5084,  0.7320]]])
>>> print(Matrix.argmin(dim=0))
tensor([[1, 0],
        [1, 0]])
>>> print(Matrix.argmin(dim=1))
tensor([[1, 1],
        [1, 0]])
>>> print(Matrix.argmin(dim=2))
tensor([[1, 1],
        [0, 0]])
>>> print(Matrix.argmin())
tensor(6)
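The indices returned with `keepdim=True` can be fed straight back into `gather` to pick out the corresponding values. A sketch reusing the numbers printed above (typed in by hand so the example is reproducible):

```python
import torch

# Same values as the argmin example above
Matrix = torch.tensor([[[ 0.5821,  0.2889],
                        [ 0.4669, -0.3135]],
                       [[-0.4567,  0.2975],
                        [-1.5084,  0.7320]]])

# keepdim=True keeps dim 2 (size 1), so the index tensor has the shape gather expects
idx = Matrix.argmin(dim=2, keepdim=True)   # shape (2, 2, 1)
picked = Matrix.gather(2, idx)             # the minimum value along dim 2

# gather with argmin's indices reproduces Matrix.min(dim=2, keepdim=True).values
assert torch.equal(picked, Matrix.min(dim=2, keepdim=True).values)

print(idx.squeeze(2))     # same as Matrix.argmin(dim=2): [[1, 1], [0, 0]]
print(picked.squeeze(2))  # values [[0.2889, -0.3135], [-0.4567, -1.5084]]
```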

II. Returning the index of an arbitrary value in PyTorch

Given a 2-D tensor of token IDs, the goal is to find the column where the value 101 sits in each row.

import torch

tens = torch.tensor([[  101,   146,  1176, 21806,  1116,  1105, 18621,   119,   102,     0,
                          0,     0,     0],
                     [  101,  1192,  1132,  1136,  1184,   146,  1354,  1128,  1127,   117,
                         1463,   119,   102],
                     [  101,  6816,  1905,  1132, 14918,   119,   102,     0,     0,     0,
                          0,     0,     0]])

# Approach 1: loop over the rows; .item() assumes exactly one match per row
idxs = torch.tensor([(row == 101).nonzero().item() for row in tens])

# Approach 2: vectorized; nonzero() returns (row, col) pairs, so column 1 holds the positions
print((tens == 101).nonzero()[:, 1])  # tensor([0, 0, 0])
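The `[:, 1]` trick lines up with the rows only when every row contains exactly one match. When a value appears several times or not at all (for example the padding value 0), a sketch that keeps the row information is:

```python
import torch

# Same token IDs as above; 0 is padding and appears several times in some rows
tens = torch.tensor([[101, 146, 1176, 21806, 1116, 1105, 18621, 119, 102, 0, 0, 0, 0],
                     [101, 1192, 1132, 1136, 1184, 146, 1354, 1128, 1127, 117, 1463, 119, 102],
                     [101, 6816, 1905, 1132, 14918, 119, 102, 0, 0, 0, 0, 0, 0]])

# With one argument, torch.where(cond) returns one index tensor per dimension,
# the same as cond.nonzero(as_tuple=True)
rows, cols = torch.where(tens == 0)

# Group the column indices by row (a row with no match gets an empty list)
per_row = [cols[rows == r].tolist() for r in range(tens.shape[0])]
print(per_row)  # [[9, 10, 11, 12], [], [7, 8, 9, 10, 11, 12]]
```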

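III. Keeping only the maximum (or minimum) of a tensor in PyTorch and zeroing all other positions

In the code below, x is the input tensor and dim specifies the dimension to reduce over (dim=1 here); max can be replaced with min.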


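import torch

if __name__ == '__main__':
    # Use the GPU when available, otherwise fall back to the CPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn([1, 3, 4, 4], device=device)

    # 1 where an element equals the max along dim 1, 0 elsewhere;
    # keepdim=True keeps shape [1, 1, 4, 4] so the comparison broadcasts over dim 1
    mask = (x == x.max(dim=1, keepdim=True)[0]).to(dtype=torch.int32)
    # Multiplying by the 0/1 mask zeroes every non-max position
    result = torch.mul(mask, x)

    print(x)
    print(mask)
    print(result)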

Output (from a run on a CUDA device):

tensor([[[[-0.8807,  0.1029,  0.0184,  1.2695],
          [-0.0934,  1.0650, -0.2927,  0.0049],
          [ 0.2338, -1.8663,  1.2763,  0.7248],
          [-1.5138,  0.6834,  0.1463,  0.0650]],

         [[ 0.5020,  1.6078, -0.0104,  1.2042],
          [ 1.8859, -0.4682, -0.1177,  0.5197],
          [ 1.7649,  0.4585,  0.6002,  0.3350],
          [-1.1384, -0.0325,  0.8490,  0.6080]],

         [[-0.5618,  0.5388, -0.0572, -0.7240],
          [-0.3458,  1.3494, -0.0603, -1.1562],
          [-0.3652,  1.1885,  1.6293,  0.4134],
          [ 1.3009,  1.2027, -0.8711,  1.3321]]]], device='cuda:0')
tensor([[[[0, 0, 1, 1],
          [0, 0, 0, 0],
          [0, 0, 0, 1],
          [0, 0, 0, 0]],

         [[1, 1, 0, 0],
          [1, 0, 0, 1],
          [1, 0, 0, 0],
          [0, 0, 1, 0]],

         [[0, 0, 0, 0],
          [0, 1, 1, 0],
          [0, 1, 1, 0],
          [1, 1, 0, 1]]]], device='cuda:0', dtype=torch.int32)
tensor([[[[-0.0000,  0.0000,  0.0184,  1.2695],
          [-0.0000,  0.0000, -0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000,  0.7248],
          [-0.0000,  0.0000,  0.0000,  0.0000]],

         [[ 0.5020,  1.6078, -0.0000,  0.0000],
          [ 1.8859, -0.0000, -0.0000,  0.5197],
          [ 1.7649,  0.0000,  0.0000,  0.0000],
          [-0.0000, -0.0000,  0.8490,  0.0000]],

         [[-0.0000,  0.0000, -0.0000, -0.0000],
          [-0.0000,  1.3494, -0.0603, -0.0000],
          [-0.0000,  1.1885,  1.6293,  0.0000],
          [ 1.3009,  1.2027, -0.0000,  1.3321]]]], device='cuda:0')
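The mask above keeps every position that ties with the maximum. If exactly one element per slice should survive, an alternative (a different technique from the mask, sketched here on CPU with a fixed seed) is to take `argmax` with `keepdim=True` and `scatter` the gathered maxima into a zero tensor:

```python
import torch

torch.manual_seed(0)
x = torch.randn(1, 3, 4, 4)  # same shape as the example above, on CPU

# Index of the max along dim 1, kept with size 1 so it can drive gather/scatter
idx = x.argmax(dim=1, keepdim=True)  # shape (1, 1, 4, 4)

# Start from zeros and write back only the max value at its argmax position
result = torch.zeros_like(x).scatter(1, idx, x.gather(1, idx))

# 0/1 mask with exactly one 1 per position along dim 1, even if values tie
mask = torch.zeros_like(x).scatter(1, idx, 1.0)
print(result)
```

For the minimum, swap `argmax` for `argmin`.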


IV. Getting the top k of each row of a tensor with PyTorch

    Come on, buddy, a question this simple? Figure it out yourself!
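For anyone who does want a starting point: `torch.topk` returns the k largest (or smallest) values and their indices along a dimension. A small sketch with made-up values, including a top-k version of the keep-and-zero trick from section III:

```python
import torch

x = torch.tensor([[0.1, 0.9, 0.4, 0.7],
                  [0.8, 0.2, 0.6, 0.3]])

k = 2
# topk along dim=1 works row by row; largest=False would give the smallest k instead
values, indices = torch.topk(x, k, dim=1, largest=True, sorted=True)
print(values)   # values [[0.9, 0.7], [0.8, 0.6]]
print(indices)  # indices [[1, 3], [0, 2]]

# Keep only the top k of each row and zero out the rest
kept = torch.zeros_like(x).scatter(1, indices, values)
print(kept)     # values [[0.0, 0.9, 0.0, 0.7], [0.8, 0.0, 0.6, 0.0]]
```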


 


Origin blog.csdn.net/l641208111/article/details/121548272