Get indices of elements in tensor a that are present in tensor b

Douglas De Rizzo Meneghetti:

For example, I want to get the indices of the elements of tensor a whose value is 0 or 2. These values (0 and 2) are stored in tensor b. I have devised a Pythonic way to do so (shown below), but a list comprehension loops over the elements one by one in the Python interpreter, so I don't think it is optimized to run on a GPU. Maybe there is a more PyTorch-native way to do it that I am unaware of.

import torch
a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])
torch.tensor([x in b for x in a]).nonzero()

# output:
tensor([[0],
        [2],
        [5],
        [6]])

Are there any other suggestions, or is this an acceptable approach?

Andreas K.:

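Here's a more efficient, vectorized way to do it (as suggested in the link posted by jodag in the comments...). Instead of looping in Python, it compares every element of a with every element of b in a single broadcast operation: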

(a[..., None] == b).any(-1).nonzero()
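To make the one-liner easier to follow, here is a sketch that splits it into its intermediate steps, using the tensors from the question. The intermediate variable names are purely illustrative. The sketch also shows `torch.isin`, a built-in membership test that, as far as I know, exists only in PyTorch 1.10 and later, so treat that line as an assumption about your installed version.

```python
import torch

a = torch.tensor([0, 1, 0, 1, 1, 0, 2])
b = torch.tensor([0, 2])

# a[..., None] adds a trailing axis, giving a shape of (7, 1). Comparing it with
# b (shape (2,)) broadcasts to a (7, 2) boolean matrix where entry [i, j] is a[i] == b[j].
matches = a[..., None] == b

# any(-1) reduces over b's axis: True where a[i] equals at least one value in b. Shape (7,).
mask = matches.any(-1)

# nonzero() returns a (k, 1) tensor of indices; flatten() turns it into a 1-D index vector.
idx_2d = mask.nonzero()
idx = idx_2d.flatten()

# Built-in alternative (assumes PyTorch >= 1.10): should produce the same boolean mask.
mask_isin = torch.isin(a, b)

print(matches.shape)  # torch.Size([7, 2])
print(idx)            # tensor([0, 2, 5, 6])
```

One design trade-off: the broadcast builds an intermediate boolean tensor of shape (len(a), len(b)), which can become large when both tensors are big. Because no Python loop is involved, the same code should also work on GPU tensors, as long as a and b are on the same device.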
