Douglas De Rizzo Meneghetti :
For example, I want to get the indices of the elements valued 0 and 2 in tensor a. These values (0 and 2) are stored in tensor b. I have devised a Pythonic way to do so (shown below), but I don't think list comprehensions are optimized to run on the GPU, and there may be a more PyTorchy way to do it that I am unaware of.
import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()
>>>> tensor([[0],
             [2],
             [5],
             [6]])
Any other suggestions or is this an acceptable way?
Andreas K. :
Here's a more efficient way to do it (as suggested in the link posted by jodag in the comments...):
(a[..., None] == b).any(-1).nonzero()
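To make it clearer why the one-liner works, here is a minimal sketch that splits it into steps, reusing the a and b from the question. The intermediate names (matches, mask, indices) are only illustrative. As a possible alternative, torch.isin performs the same membership test directly; this assumes PyTorch 1.10 or newer, where that function is available.

```python
import torch

a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])

# a[..., None] has shape (7, 1); comparing it with b (shape (2,))
# broadcasts to a (7, 2) boolean tensor: one row per element of a,
# one column per candidate value in b.
matches = a[..., None] == b

# An element of a is "in b" if any column in its row is True.
mask = matches.any(-1)      # shape (7,)

# nonzero() returns the positions of the True entries as a column.
indices = mask.nonzero()    # shape (4, 1)

# Alternative, assuming PyTorch >= 1.10: torch.isin builds the same mask.
mask_isin = torch.isin(a, b)
indices_isin = mask_isin.nonzero()

print(indices.flatten().tolist())
print(torch.equal(indices, indices_isin))
```

Note that the broadcast comparison materializes a len(a) x len(b) intermediate tensor, so for a large b it can use a lot of memory; torch.isin avoids writing that comparison by hand.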