torch.index_select()

Function form:

torch.index_select(input, dim, index)   # method form: input.index_select(dim, index)

Parameters:

  • input: the tensor to select from (in the method form, this is the tensor the method is called on);
  • dim: the dimension along which to select, an int;
  • index: the positions to pick along dim, a 1-D integer tensor (int64 or int32), i.e. an instance of torch.Tensor;
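To make the parameter list concrete, here is a small sketch (my own illustration, not from the original post). It shows that the function and method forms agree, that the size along dim becomes the length of index while the other dimensions stay the same, and that index may reorder or repeat positions.

```python
import torch

# The same 3x4 float tensor used later in this post
a = torch.linspace(1, 12, steps=12).view(3, 4)

# index must be a 1-D integer tensor; order is kept and repeats are allowed
idx = torch.tensor([2, 0, 2])

# Function form and method form give the same result
r1 = torch.index_select(a, 0, idx)
r2 = a.index_select(0, idx)
print(torch.equal(r1, r2))  # True

# Along dim 0 the size becomes len(idx); the other dimensions are unchanged
print(r1.shape)   # torch.Size([3, 4])
print(r1[:, 0])   # tensor([9., 1., 9.])

# Selecting a single column along dim 1 keeps that dimension (size 1)
r3 = a.index_select(1, torch.tensor([3]))
print(r3.shape)   # torch.Size([3, 1])
```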

I had just started learning PyTorch when I ran into index_select(). At first I didn't understand what its parameters meant, but after looking it up I understood it a little better.

import torch

a = torch.linspace(1, 12, steps=12).view(3, 4)
print(a)

b = torch.index_select(a, 0, torch.tensor([0, 2]))  # rows 0 and 2
print(b)
print(a.index_select(0, torch.tensor([0, 2])))      # method form, same result

c = torch.index_select(a, 1, torch.tensor([1, 3]))  # columns 1 and 3
print(c)

First, a 3×4 tensor holding the values 1 to 12 is created with the linspace and view methods.

The first argument is the tensor to index. The second argument is the dimension: for this 2-D tensor, 0 means selecting rows and 1 means selecting columns. The third argument is a tensor of (zero-based) index numbers. For example, tensor([0, 2]) in b selects rows 0 and 2, and tensor([1, 3]) in c selects columns 1 and 3.

The output is as follows:

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 9., 10., 11., 12.]])
tensor([[ 2.,  4.],
        [ 6.,  8.],
        [10., 12.]])
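As a cross-check on the output above, the sketch below (my own addition) compares index_select with ordinary integer-array indexing, which gives the same rows and columns for a 2-D tensor. It also shows that index_select copies the data rather than returning a view: the PyTorch documentation states that the returned tensor does not share storage with the input, so writing to b leaves a unchanged.

```python
import torch

a = torch.linspace(1, 12, steps=12).view(3, 4)
rows = torch.tensor([0, 2])
cols = torch.tensor([1, 3])

b = torch.index_select(a, 0, rows)
c = torch.index_select(a, 1, cols)

# Same results as integer-array indexing
print(torch.equal(b, a[rows]))      # True
print(torch.equal(c, a[:, cols]))   # True

# index_select returns a new tensor, so modifying b does not touch a
b[0, 0] = 100.0
print(a[0, 0])  # tensor(1.)
```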

Function: selects the entries at the given positions along a specified dimension of a tensor.

Code example:

import torch

t = torch.arange(24).reshape(2, 3, 4)  # initialize a tensor holding 0 to 23, shape (2, 3, 4)
print("t--->", t)

index = torch.tensor([1, 2])  # positions of the data to select
print("index--->", index)

data1 = t.index_select(1, index)  # first argument: select along dim 1; second argument: positions within that dim
print("data1--->", data1)

data2 = t.index_select(2, index)  # first argument: select along dim 2; second argument: positions within that dim
print("data2--->", data2)

Output:

t---> tensor([[[ 0,  1,  2,  3],
               [ 4,  5,  6,  7],
               [ 8,  9, 10, 11]],
 
              [[12, 13, 14, 15],
               [16, 17, 18, 19],
               [20, 21, 22, 23]]])
 
index---> tensor([1, 2])
 
data1---> tensor([[[ 4,  5,  6,  7],
                   [ 8,  9, 10, 11]],
 
                  [[16, 17, 18, 19],
                   [20, 21, 22, 23]]])
 
data2---> tensor([[[ 1,  2],
                   [ 5,  6],
                   [ 9, 10]],
 
                  [[13, 14],
                   [17, 18],
                   [21, 22]]])
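The 3-D example generalizes as follows: selecting along dim d takes one slice per index along d and stacks the slices back along d. The sketch below spells this out with a hypothetical helper, index_select_by_hand (for illustration only, not a PyTorch function), built from Tensor.select and torch.stack, and checks it against data1 and data2 from the example. dim 0 is skipped because it has size 2, so index 2 would be out of range there.

```python
import torch

t = torch.arange(24).reshape(2, 3, 4)
index = torch.tensor([1, 2])

def index_select_by_hand(x, dim, idx):
    # Hypothetical helper for illustration: take the slice at each index
    # along `dim`, then stack the slices back along the same dimension
    return torch.stack([x.select(dim, int(i)) for i in idx], dim=dim)

for dim in (1, 2):
    expected = t.index_select(dim, index)
    print(dim, torch.equal(index_select_by_hand(t, dim, index), expected))  # True

# Equivalent slicing: keep every other dimension whole
print(torch.equal(t.index_select(1, index), t[:, index]))     # True
print(torch.equal(t.index_select(2, index), t[:, :, index]))  # True
print(t.index_select(1, index).shape)  # torch.Size([2, 2, 4])
print(t.index_select(2, index).shape)  # torch.Size([2, 3, 2])
```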


Origin blog.csdn.net/weixin_36670529/article/details/113817994