pytorch中index_select的用法

最近开始接触pytorch,发现torch.index_select不是很好理解,就查了一下文档:

torch.index_select(input, dim, index, out=None) → Tensor
  • input :输入一个张量

  • dim :索引所依赖的维度

  •  index :索引的index

  •  out :返回的张量,默认为None

torch.index_select()在input的张量在dim的维度上按照index索引并返回一个新的张量。这样子不太好理解,我们拿一个例子来看看:

导入torch包,创建一个3×4的二维张量x

传入x,并在第0维上(这里是二维,所以直接是在行上)按照[0,2]索引(也就是选择index为0和2的行,也就是第1行和第三行)

同理,这里只是把维度改成了1,这里就是索引x的第一列和第三列。

这样子写也能实现一样的效果。

另外,官方文档上面有这样一段描述:

意思就是说,返回的张量和输入的张量具有相同的维数。然后返回的张量中,第dim维(就是代码中的dim)跟的大小跟代码中的index一样长,其它维的大小跟输入张量保持一致,还是上面的例子如下图:

输入张量是3×4,这是在第1维(此处就是列)进行索引,indices长度为2。输出张量是第0维是3,与输入张量一致,第0维是2,与indices一致。

发布了5 篇原创文章 · 获赞 19 · 访问量 1066

猜你喜欢

转载自blog.csdn.net/weixin_44976373/article/details/104564663