A summary of some torch functions (continuously updated)

torch.gather

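import torch

b = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # float tensor, shape (2, 3)
print(b)
index_1 = torch.LongTensor([[0, 1], [2, 0]])  # column indices, used with dim=1
index_2 = torch.LongTensor([[0, 1, 1], [0, 0, 0]])  # row indices, used with dim=0
print(torch.gather(b, dim=1, index=index_1))
print(torch.gather(b, dim=0, index=index_2))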
b:
tensor([[1., 2., 3.],
        [4., 5., 6.]])
index_1:
tensor([[0, 1],
        [2, 0]])
index_2:
tensor([[0, 1, 1],
        [0, 0, 0]])
print(torch.gather(b, dim=1, index=index_1))
tensor([[1., 2.],
        [6., 4.]])

print(torch.gather(b, dim=0, index=index_2))
tensor([[1., 5., 6.],
        [1., 2., 3.]])
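  • If dim = 1: each element of index_1 is a column index into the same row of b, i.e. out[i][j] = b[i][index_1[i][j]]. Row 0 of index_1 is (0, 1), so it takes b[0][0] = 1 and b[0][1] = 2; row 1 is (2, 0), so it takes b[1][2] = 6 and b[1][0] = 4.
  • If dim = 0: each element of index_2 is a row index into the same column of b, i.e. out[i][j] = b[index_2[i][j]][j]. The first column of index_2 is (0, 0), so it takes rows 0 and 0 of the first column of b, giving (1, 1); the second column is (1, 0), which picks (5, 2) from the second column of b; the third column is (1, 0), which picks (6, 3) from the third column of b.
In both cases the output has the same shape as the index tensor.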
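To make the two indexing rules concrete, here is a minimal plain-Python sketch for 2-D nested lists. gather_2d is a hypothetical helper written only for illustration; it is not part of torch, and it only handles dim=0 and dim=1 without the bounds and shape checks that torch.gather performs.

```python
def gather_2d(src, dim, index):
    """Hypothetical helper that mimics torch.gather for 2-D nested lists.

    dim=0: out[i][j] = src[index[i][j]][j]   (index picks the row)
    dim=1: out[i][j] = src[i][index[i][j]]   (index picks the column)
    The result has the same shape as `index`.
    """
    if dim not in (0, 1):
        raise ValueError("gather_2d only supports dim=0 or dim=1")
    out = []
    for i, row in enumerate(index):
        out_row = []
        for j, k in enumerate(row):
            # k replaces the coordinate along `dim`; the other coordinate stays (i or j)
            out_row.append(src[k][j] if dim == 0 else src[i][k])
        out.append(out_row)
    return out


b = [[1, 2, 3], [4, 5, 6]]
print(gather_2d(b, dim=1, index=[[0, 1], [2, 0]]))        # [[1, 2], [6, 4]]
print(gather_2d(b, dim=0, index=[[0, 1, 1], [0, 0, 0]]))  # [[1, 5, 6], [1, 2, 3]]
```

The two printed results match the torch.gather outputs shown above, which is a quick way to check one's reading of the dim argument.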

torch.eq(predict_labels, labels)

Here predict_labels and labels are two tensors of the same shape. torch.eq() compares them element by element and returns a boolean tensor of that same shape, holding True where the two values are equal and False where they differ.

predict_labels = torch.LongTensor([0, 1, 2, 3, 4])
labels = torch.LongTensor([4, 3, 2, 1, 4])

torch.eq(predict_labels, labels)
Out[15]: tensor([False, False,  True, False,  True])

torch.eq(predict_labels, labels).sum()
Out[16]: tensor(2)

torch.eq(predict_labels, labels).sum().item()
Out[17]: 2
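A common reason to count matches this way is to compute classification accuracy. The sketch below assumes predictions and labels are 1-D integer tensors of the same shape; accuracy is a hypothetical helper name, not a torch function. Note that predict_labels == labels is equivalent to torch.eq(predict_labels, labels).

```python
import torch


def accuracy(predict_labels: torch.Tensor, labels: torch.Tensor) -> float:
    """Hypothetical helper: fraction of positions where prediction equals label."""
    if predict_labels.shape != labels.shape:
        raise ValueError("predict_labels and labels must have the same shape")
    # Boolean tensor -> number of True entries -> plain Python int
    correct = torch.eq(predict_labels, labels).sum().item()
    return correct / labels.numel()


predict_labels = torch.tensor([0, 1, 2, 3, 4])
labels = torch.tensor([4, 3, 2, 1, 4])
print(accuracy(predict_labels, labels))  # 0.4
```

Calling .item() before dividing keeps the result a Python float instead of a tensor, which is convenient for logging.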


Origin blog.csdn.net/weixin_54546190/article/details/126585016