torch中的topk()函数

torch中的topk()函数

In [2]: import torch

In [3]:  a=torch.randn((4,6))

In [4]: a
Out[4]: 
tensor([[-0.5215,  1.3219, -2.0798, -2.3303, -0.3767, -0.8851],
        [-1.6861, -1.7882,  0.2139, -0.4486, -0.3331, -2.1024],
        [-0.5967, -0.5847,  0.0688, -0.9874, -1.4095,  0.2968],
        [ 0.1683,  1.0879, -0.8358,  0.2569,  1.1068, -0.2029]])

In [5]: _,pred = a.topk(1)

In [6]: _
Out[6]: 
tensor([[1.3219],
        [0.2139],
        [0.2968],
        [1.1068]])

In [7]: pred
Out[7]: 
tensor([[1],
        [2],
        [5],
        [4]])
发布了790 篇原创文章 · 获赞 100 · 访问量 14万+

猜你喜欢

转载自blog.csdn.net/weixin_44510615/article/details/104414604