PyTorch: indexing and slicing

indexing

import torch
a = torch.rand(4,3,28,28)
a
tensor([[[[0.1683, 0.1747, 0.7508,  ..., 0.2379, 0.2763, 0.9050],
          [0.4666, 0.1574, 0.5551,  ..., 0.6048, 0.1089, 0.1808],
          [0.3799, 0.0634, 0.3741,  ..., 0.6716, 0.7283, 0.7299],
          ...,
          [0.2187, 0.3803, 0.2877,  ..., 0.0674, 0.4328, 0.7230],
          [0.9149, 0.1335, 0.8394,  ..., 0.9956, 0.8328, 0.1046],
          [0.5468, 0.9828, 0.3121,  ..., 0.1648, 0.4289, 0.5186]],


         [[0.2494, 0.5421, 0.0255,  ..., 0.7555, 0.3055, 0.5643],
          [0.3268, 0.5198, 0.7447,  ..., 0.0771, 0.0692, 0.4214],
          [0.8027, 0.3809, 0.0480,  ..., 0.7824, 0.0415, 0.7449],
          ...,
          [0.0783, 0.6822, 0.4180,  ..., 0.5414, 0.7171, 0.3447],
          [0.6471, 0.2972, 0.0571,  ..., 0.0784, 0.2780, 0.9263],
          [0.7520, 0.4694, 0.4576,  ..., 0.6423, 0.6848, 0.5141]]]])
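
Here a can be read as a batch of 4 three-channel 28×28 images. As a quick sketch (using the same a as above), the basic shape queries look like this:

a.dim()
4
a.numel() # 4 * 3 * 28 * 28
9408
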
a[0].shape
torch.Size([3, 28, 28])
a[0]
tensor([[[0.1683, 0.1747, 0.7508,  ..., 0.2379, 0.2763, 0.9050],
         [0.4666, 0.1574, 0.5551,  ..., 0.6048, 0.1089, 0.1808],
         [0.3799, 0.0634, 0.3741,  ..., 0.6716, 0.7283, 0.7299],
         ...,
         [0.2187, 0.3803, 0.2877,  ..., 0.0674, 0.4328, 0.7230],
         [0.9149, 0.1335, 0.8394,  ..., 0.9956, 0.8328, 0.1046],
         [0.5468, 0.9828, 0.3121,  ..., 0.1648, 0.4289, 0.5186]],

        [[0.0309, 0.6258, 0.1917,  ..., 0.3253, 0.4002, 0.6112],
         [0.9583, 0.4674, 0.9571,  ..., 0.8539, 0.9874, 0.9117],
         [0.9209, 0.0456, 0.7559,  ..., 0.5963, 0.7595, 0.6844],
         ...,
         [0.3391, 0.5333, 0.1412,  ..., 0.4589, 0.1842, 0.6238],
         [0.8586, 0.6629, 0.7139,  ..., 0.0109, 0.1477, 0.4978],
         [0.5280, 0.5506, 0.5176,  ..., 0.2900, 0.1188, 0.6285]],

        [[0.0303, 0.5729, 0.5613,  ..., 0.0277, 0.1602, 0.3827],
         [0.4207, 0.1314, 0.2211,  ..., 0.2983, 0.1922, 0.7292],
         [0.5850, 0.4933, 0.5886,  ..., 0.5713, 0.8765, 0.1331],
         ...,
         [0.3938, 0.5237, 0.6337,  ..., 0.8795, 0.6323, 0.7085],
         [0.1957, 0.9536, 0.2277,  ..., 0.7959, 0.8169, 0.8741],
         [0.5623, 0.1302, 0.8132,  ..., 0.0622, 0.2000, 0.4052]]])
a[0,0].shape
torch.Size([28, 28])
a[0,0,2,4]
tensor(0.5822)
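
Note that each integer index removes one dimension, while a slice (even of length 1) keeps it, and negative indices count from the end. A minimal sketch with the same a:

a[0:1].shape  # length-1 slice: dim 0 is kept, unlike a[0]
torch.Size([1, 3, 28, 28])
a[-1].shape   # negative index: the last image, dim 0 removed
torch.Size([3, 28, 28])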

select first/last N

a.shape
torch.Size([4, 3, 28, 28])
a[:,2].shape
torch.Size([4, 28, 28])
a[:2].shape # 0,1
torch.Size([2, 3, 28, 28])
a[:2, :1, :, :].shape # 0,1 | 0 | all | all
torch.Size([2, 1, 28, 28])
a[:2, 1, :, :].shape # 0,1 | index 1 (integer index, dim removed) | all | all
torch.Size([2, 28, 28])
a[:2, -1:, :, :].shape # 0,1 | last channel (index 2) | all | all; indices run from left to right
torch.Size([2, 1, 28, 28])
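
Negative slice bounds also count from the end, so the last N entries along a dimension can be selected the same way. A short sketch:

a[-2:].shape     # last two images along dim 0
torch.Size([2, 3, 28, 28])
a[:, -1:].shape  # last channel; dim kept because it is a slice
torch.Size([4, 1, 28, 28])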

select by steps (start:stop:step)

a[:, :, 0:28:2, 0:28:2].shape # all | all | step=2 | step=2
torch.Size([4, 3, 14, 14])
a[:, :, ::2, ::2].shape
torch.Size([4, 3, 14, 14])
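
The general slice form is start:stop:step, and each part may be omitted. One caveat worth noting: unlike NumPy, PyTorch slicing rejects a negative step, so reversing a dimension needs torch.flip. A brief sketch:

a[:, :, ::7, ::7].shape  # every 7th row and column: indices 0,7,14,21
torch.Size([4, 3, 4, 4])
torch.flip(a, dims=[2]).shape  # reverse the rows; a[:, :, ::-1] would raise an error
torch.Size([4, 3, 28, 28])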

select by specific index (index_select)

a.shape
torch.Size([4, 3, 28, 28])
# The first argument of index_select() is the dimension to select along (dims are indexed 0,1,2,3);
# the second argument is a tensor of the indices to keep within that dimension.
a.index_select(0, torch.tensor([0,2])).shape
torch.Size([2, 3, 28, 28])
a.index_select(1, torch.tensor([1,2])).shape
torch.Size([4, 2, 28, 28])
a.index_select(2, torch.arange(28)).shape
torch.Size([4, 3, 28, 28])
a.index_select(2, torch.arange(8)).shape
torch.Size([4, 3, 8, 28])
torch.arange(8)
tensor([0, 1, 2, 3, 4, 5, 6, 7])
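
index_select behaves like fancy indexing with an index tensor along the given dimension, and the indices may repeat or reorder entries. A minimal sketch:

a[torch.tensor([0, 2])].shape  # same result as a.index_select(0, torch.tensor([0, 2]))
torch.Size([2, 3, 28, 28])
a.index_select(0, torch.tensor([2, 2, 0])).shape  # repeating and reordering are allowed
torch.Size([3, 3, 28, 28])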

"..." stands for any dimension

a.shape
torch.Size([4, 3, 28, 28])
a[...].shape
torch.Size([4, 3, 28, 28])
a[0, ...].shape
torch.Size([3, 28, 28])
a[:, 1, ...].shape
torch.Size([4, 28, 28])
a[..., :2].shape # 0,1
torch.Size([4, 3, 28, 2])
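
In other words, "..." expands to as many ":" as needed, so it can also sit between explicit indices. A short sketch:

a[0, ..., 2].shape  # equivalent to a[0, :, :, 2]
torch.Size([3, 28])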

select by mask

x = torch.randn(3, 4)
x
tensor([[-2.0654, -0.2259,  0.5090, -0.2166],
        [-1.0956, -0.6469,  0.0136,  0.1804],
        [ 0.5863,  0.9568, -0.5896,  0.2588]])
mask = x.ge(0.5)  # elementwise x >= 0.5
mask
tensor([[False, False,  True, False],
        [False, False, False, False],
        [ True,  True, False, False]])
torch.masked_select(x, mask)
tensor([0.5090, 0.5863, 0.9568])
torch.masked_select(x, mask).shape
torch.Size([3])
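
masked_select always copies the selected elements into a new 1-D tensor, so the output length depends on the values in x. Boolean-mask indexing gives the same result; with the same x and mask as above:

x[mask]  # equivalent to torch.masked_select(x, mask)
tensor([0.5090, 0.5863, 0.9568])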

select by flatten index

# flat (row-major) indices: 0,1,2,3,4,5
# values:                   4,3,5,6,7,8
src = torch.tensor([[4,3,5],[6,7,8]])
src
tensor([[4, 3, 5],
        [6, 7, 8]])
torch.take(src, torch.tensor([0,2,5])) # flattens the tensor first, then indexes
tensor([4, 5, 8])
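
Since torch.take flattens its input in row-major order before indexing, it is equivalent to indexing the flattened tensor directly. A minimal sketch with the same src:

src.flatten()[torch.tensor([0, 2, 5])]  # same as torch.take(src, torch.tensor([0, 2, 5]))
tensor([4, 5, 8])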

Source: https://blog.csdn.net/MasterCayman/article/details/109393936