torch.topk

1维的可以直接取值,

    import torch

    anch_ious = torch.Tensor([[1, 2, 3], [4, 5, 6]]).view(-1)


    neg_count=4
    top_data,index= torch.topk(anch_ious, neg_count, dim=0, largest=True, sorted=True, out=None)

    print(anch_ious[index])

2维以上就行了:

能返回index,但是不能根据index获取到值

需要根据指定维度取值,用gather

    import torch

    anch_ious = torch.Tensor([[1, 2, 3], [4, 5, 6]])

    neg_count=2
    top_data,index= torch.topk(anch_ious, neg_count, dim=1, largest=True, sorted=True, out=None)

    print(top_data)

    # b = torch.LongTensor([0, 1]).view(2, 1)

    c = torch.gather(input=anch_ious, dim=1, index=index)
    print(c)
发布了2732 篇原创文章 · 获赞 1011 · 访问量 538万+

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/104771822