Advanced tensor operations
▪ Where
▪ Gather
Where
import torch
# Random 2x2 condition tensor; its sign decides which source torch.where picks.
cond = torch.randn((2, 2))
cond
tensor([[ 2.3397, 1.2282],
[-1.7580, 0.9515]])
# Value taken where the condition holds.
a = torch.ones((2, 2))
a
tensor([[1., 1.],
[1., 1.]])
# Value taken where the condition fails.
b = torch.zeros((2, 2))
b
tensor([[0., 0.],
[0., 0.]])
# Element-wise select: 1 (from a) where cond is positive, otherwise 0 (from b).
result = torch.where(cond.gt(0), a, b)
result
tensor([[1., 1.],
[0., 1.]])
Gather
# Random 4x10 score matrix: 4 rows, 10 columns each.
prob = torch.randn((4, 10))
prob
tensor([[ 1.6869, -0.8115, -0.2504, -0.1094, 0.7893, 0.4076, -0.0311, -0.6119,
-0.1706, -0.6414],
[ 0.6565, 0.0926, -0.4102, 1.0853, -0.3517, -0.4270, 0.2243, -0.5348,
0.6137, -0.4874],
[-0.6740, -2.5947, -0.3836, -1.7365, -1.7719, 0.2721, 0.7912, 0.2159,
1.1237, -0.7022],
[ 0.3968, -0.3211, 2.3566, -0.0133, 1.5263, 0.6008, -0.7640, 0.2766,
0.3015, -0.3570]])
# Take the top-3 entries of every row. Note: despite its name, idx is the
# (values, indices) named tuple returned by topk, not just the indices.
idx = torch.topk(prob, k=3, dim=1)
idx
torch.return_types.topk(
values=tensor([[1.6869, 0.7893, 0.4076],
[1.0853, 0.6565, 0.6137],
[1.1237, 0.7912, 0.2721],
[2.3566, 1.5263, 0.6008]]),
indices=tensor([[0, 4, 5],
[3, 0, 8],
[8, 6, 5],
[2, 4, 5]]))
# Unpack the topk result: idx1 holds the top-3 values, idx2 their column indices.
idx1, idx2 = idx
print(idx1)
print(idx2)
tensor([[1.6869, 0.7893, 0.4076],
[1.0853, 0.6565, 0.6137],
[1.1237, 0.7912, 0.2721],
[2.3566, 1.5263, 0.6008]])
tensor([[0, 4, 5],
[3, 0, 8],
[8, 6, 5],
[2, 4, 5]])
# One label per column: column j of prob is labelled 100 + j.
label = torch.arange(100, 110)
label
tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
# Map each row's top-3 column indices to their labels: res1[i, j] = label[idx2[i, j]].
# torch.gather needs input and index to have the same number of dims, so the 1-D
# label row is expanded (a view, no copy) to one row per row of idx2. The row count
# is taken from idx2 rather than hard-coded, so this still works if prob's batch
# size changes. topk indices are already int64, as gather requires, so no .long()
# conversion is needed.
res1 = torch.gather(label.expand(idx2.size(0), -1), dim=1, index=idx2)
res1
tensor([[100, 104, 105],
[103, 100, 108],
[108, 106, 105],
[102, 104, 105]])
label.expand(4, -1)  # repeat the 10-label row 4 times (-1 keeps the existing length)
tensor([[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109],
[100, 101, 102, 103, 104, 105, 106, 107, 108, 109]])
# Explicit int64 conversion of the topk indices (compare with the next cell).
index = idx2.to(torch.long)
index
tensor([[0, 4, 5],
[3, 0, 8],
[8, 6, 5],
[2, 4, 5]])
# Same indices without .long(): the display matches the previous cell, so the
# conversion was a no-op for these topk indices.
index = idx2
index
tensor([[0, 4, 5],
[3, 0, 8],
[8, 6, 5],
[2, 4, 5]])