pytorch函数之torch.index_select

torch.index_select函数,顾名思义就是根据index索引在input输入张量中选择某些特定的元素,下面介绍该函数的参数。

torch.index_select(input, dim, index, out=None):

  1. input:输入Tensor,在该Tensor上根据index和dim进行切片;
  2. dim:切片维度,在input这个Tensor的哪个维度上进行dinex索引;
  3. index:Tensor.LongTensor类型的1-D Tensor,在dim维度上需要索引的下标(自己尝试过非1-D的index,结果报错 Index is supposed to be
    1-dimensional,如有不对欢迎指正);
  4. out:用来承载函数的返回值(也可以直接用变量x=torch.index_select(input, dim, index)进行承载,不需要out参数)
import torch

x = torch.rand(3,5)
index = torch.LongTensor([2,0])

# 如果想在x的第一个维度上选择x[2]和x[0]
y = torch.index_select(x, dim=0, index=index)

# 如果想在x的第二个维度上选择,即x[...,2]和x[...,0]
y = torch.index_select(x, dim=1, index=index)

# 另外,也可以用以下方法
y = x.new()
torch.index_select(x, dim=0, index=index, out=y)
发布了9 篇原创文章 · 获赞 0 · 访问量 35

猜你喜欢

转载自blog.csdn.net/qq_41092190/article/details/105403943