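torch.nonzero() in PyTorch returns the indices of the non-zero elements of a tensor. By default the result is a 2-D tensor with one row per non-zero element and one column per dimension of the input. For a boolean tensor, True counts as non-zero, so applying it to the result of a comparison gives the positions where the condition holds.
For example: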
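import torch
x = torch.tensor([4, 0, 1, 2, 1, 2, 3])
result = x == 1  # boolean mask: True where x equals 1
print(result)
print(result.nonzero())  # indices of the non-zero (True) entries, shape (2, 1)
print(result.nonzero().view(-1))  # flatten the result to a 1-D tensor: tensor([2, 4])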
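The view(-1) step is needed because the default output is 2-D even for a 1-D input. The sketch below, which assumes a recent PyTorch release, contrasts that default form with the as_tuple=True form (one 1-D index tensor per input dimension) and shows what the rows mean for a 2-D input. The variable names (mask, idx_2d, idx_1d, idx_tuple, m, coords) are illustrative only.

```python
import torch

x = torch.tensor([4, 0, 1, 2, 1, 2, 3])
mask = x == 1

# Default form: a 2-D tensor of shape (number of non-zero elements, x.dim())
idx_2d = torch.nonzero(mask)
# Flatten to 1-D; same result as .view(-1) here because the last dimension is 1
idx_1d = idx_2d.flatten()

# as_tuple=True returns a tuple with one 1-D index tensor per input dimension
(idx_tuple,) = torch.nonzero(mask, as_tuple=True)

# For a 2-D input, each row of the default result is a (row, col) coordinate
m = torch.tensor([[0, 5], [7, 0]])
coords = torch.nonzero(m)

print(idx_1d)
print(idx_tuple)
print(coords)
```

For 1-D inputs, as_tuple=True avoids the extra reshape. For multi-dimensional inputs, the tuple can be used directly to index the original tensor, e.g. m[torch.nonzero(m, as_tuple=True)].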