PyTorch 高阶操作（OP）：where 与 gather

（原文此处为配图，转载时未保留）
（原文此处为配图，转载时未保留）

Tensor 高阶操作（Tensor advanced operations）：

- `torch.where`

- `torch.gather`

Where

`torch.where(condition, x, y)` 逐元素选择：condition 为 True 的位置取 x 的对应元素，否则取 y 的对应元素。

import torch
# torch.where(condition, x, y) selects element-wise: where `condition` is True
# the result takes the element of `x`, otherwise the element of `y`.
# The commented outputs below come from one sample run; torch.randn is
# unseeded here, so the numbers you see will differ.
cond = torch.randn(2, 2)
print(cond)
# tensor([[ 2.3397,  1.2282],
#         [-1.7580,  0.9515]])

a = torch.ones(2, 2)
print(a)
# tensor([[1., 1.],
#         [1., 1.]])

b = torch.zeros(2, 2)
print(b)
# tensor([[0., 0.],
#         [0., 0.]])

# `cond > 0` is a bool mask: positive entries pick 1. from `a`,
# all other entries pick 0. from `b`.
result = torch.where(cond > 0, a, b)
print(result)
# tensor([[1., 1.],
#         [0., 1.]])

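Gather

`torch.gather(input, dim, index)` 沿 dim 维按 index 取值，输出形状与 index 相同；当 dim=1 时，out[i][j] = input[i][index[i][j]]。下面的例子先用 topk 取出每行概率最大的 3 个位置，再用 gather 把这些位置映射成对应的标签。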

# Take the top-3 entries of each row of `prob`, then use torch.gather to turn
# their column indices into labels. The commented outputs below come from one
# sample run; torch.randn is unseeded here, so the numbers you see will differ.
prob = torch.randn(4, 10)
print(prob)
# tensor([[ 1.6869, -0.8115, -0.2504, -0.1094,  0.7893,  0.4076, -0.0311, -0.6119,
#          -0.1706, -0.6414],
#         [ 0.6565,  0.0926, -0.4102,  1.0853, -0.3517, -0.4270,  0.2243, -0.5348,
#           0.6137, -0.4874],
#         [-0.6740, -2.5947, -0.3836, -1.7365, -1.7719,  0.2721,  0.7912,  0.2159,
#           1.1237, -0.7022],
#         [ 0.3968, -0.3211,  2.3566, -0.0133,  1.5263,  0.6008, -0.7640,  0.2766,
#           0.3015, -0.3570]])

# topk returns a named tuple (values, indices), not just the indices.
idx = prob.topk(k=3, dim=1)  # top-3 per row
print(idx)
# torch.return_types.topk(
# values=tensor([[1.6869, 0.7893, 0.4076],
#         [1.0853, 0.6565, 0.6137],
#         [1.1237, 0.7912, 0.2721],
#         [2.3566, 1.5263, 0.6008]]),
# indices=tensor([[0, 4, 5],
#         [3, 0, 8],
#         [8, 6, 5],
#         [2, 4, 5]]))

idx1 = idx.values   # equivalent to idx[0]
print(idx1)
# tensor([[1.6869, 0.7893, 0.4076],
#         [1.0853, 0.6565, 0.6137],
#         [1.1237, 0.7912, 0.2721],
#         [2.3566, 1.5263, 0.6008]])
idx2 = idx.indices  # equivalent to idx[1]; dtype is torch.int64
print(idx2)
# tensor([[0, 4, 5],
#         [3, 0, 8],
#         [8, 6, 5],
#         [2, 4, 5]])

label = torch.arange(10) + 100
print(label)
# tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])

# expand(prob.size(0), -1) repeats `label` once per row of `prob` as a view
# (no copy), instead of hard-coding the row count 4. gather needs an int64
# index; topk indices already are int64, so no .long() conversion is needed.
res1 = torch.gather(label.expand(prob.size(0), -1), dim=1, index=idx2)
print(res1)
# tensor([[100, 104, 105],
#         [103, 100, 108],
#         [108, 106, 105],
#         [102, 104, 105]])

print(label.expand(prob.size(0), -1))
# tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
#         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
#         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
#         [100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])

# .long() is a no-op on the topk indices: both forms print the same tensor.
index = idx2.long()
print(index)
# tensor([[0, 4, 5],
#         [3, 0, 8],
#         [8, 6, 5],
#         [2, 4, 5]])
index = idx2
print(index)
# tensor([[0, 4, 5],
#         [3, 0, 8],
#         [8, 6, 5],
#         [2, 4, 5]])

参考:
知乎-gather

简书-gather

猜你喜欢

转载自blog.csdn.net/MasterCayman/article/details/109408389