torch中的topk()函数
# Demo: Tensor.topk(k) returns a (values, indices) pair holding the k largest
# entries along one dimension (the last dimension when `dim` is not given).
import torch

# A fixed 4x6 input instead of an unseeded torch.randn((4, 6)), so the outputs
# recorded below can be reproduced exactly. Any float tensor works here, e.g.
# `torch.manual_seed(0); a = torch.randn((4, 6))` for a reproducible random one.
a = torch.tensor([
    [-0.5215,  1.3219, -2.0798, -2.3303, -0.3767, -0.8851],
    [-1.6861, -1.7882,  0.2139, -0.4486, -0.3331, -2.1024],
    [-0.5967, -0.5847,  0.0688, -0.9874, -1.4095,  0.2968],
    [ 0.1683,  1.0879, -0.8358,  0.2569,  1.1068, -0.2029],
])
print(a)

# k=1 along the last dimension: the largest entry of each row (`values`) and
# its column index (`pred`); both come back with shape (4, 1).
# The values get a real name rather than `_`: `_` conventionally means
# "discarded", and in IPython assigning to `_` stops IPython from updating its
# output-history shortcuts (`_`, `__`, `___`).
values, pred = a.topk(1, dim=-1)

print(values)
# tensor([[1.3219],
#         [0.2139],
#         [0.2968],
#         [1.1068]])

print(pred)
# tensor([[1],
#         [2],
#         [5],
#         [4]])