torch.nonzero()

In PyTorch, torch.nonzero() returns the indices of the elements of a tensor that are not zero. For a boolean tensor, True counts as non-zero, so nonzero() gives the positions where a condition holds.
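The returned tensor is always 2-D. For an input with n dimensions that holds z non-zero elements, torch.nonzero() returns a tensor of shape (z, n), with one row of coordinates per non-zero element. The sketch below uses an arbitrary 2-D tensor (the values in m are made up for illustration). It also shows the as_tuple=True option, which returns one 1-D index tensor per dimension instead.

```python
import torch

# A 2-D tensor: nonzero() returns one row per non-zero element,
# and each row holds that element's (row, column) coordinates.
m = torch.tensor([[0, 5, 0],
                  [7, 0, 9]])
idx = torch.nonzero(m)
print(idx)  # shape (3, 2): rows [0, 1], [1, 0], [1, 2]

# as_tuple=True returns one 1-D index tensor per dimension,
# which can be used directly to index the original tensor.
rows, cols = torch.nonzero(m, as_tuple=True)
print(m[rows, cols])  # the non-zero values themselves: 5, 7, 9
```

The coordinates come back in row-major order, which is why the tuple form can be fed straight back into indexing to recover the non-zero values.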

For example:

import torch

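x = torch.tensor([4, 0, 1, 2, 1, 2, 3])
result = x == 1  # boolean mask: True where x equals 1 (positions 2 and 4)
print(result)

print(result.nonzero())  # indices of the non-zero (True) elements, shape (2, 1): [[2], [4]]
print(result.nonzero().view(-1))  # flatten the result into a 1-D tensor: [2, 4]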


Origin blog.csdn.net/t20134297/article/details/108236615