python中的高级切片(索引为数组)

今天写代码遇到了一个需求,那就是我有一个大小为(10,20,30,6)大小的数组a,还有一个大小为(10, 20, 4)大小的索引indice, 索引中是4个0到29不重复的整数,我希望利用这个indice岁数组a进行切片

a = torch.arange(10 * 20 * 30 * 6).reshape(10, 20, 30, 6)
indice = torch.randint(1, 30, (10, 20, 4))

直接使用a[indice]无法达到想要的结果
a[:,:,indice]得到的结果大小为(10,20,10,20,6),也不正确
解决方法

a[torch.arange(a.shape[0])[:, None, None], torch.arange(a.shape[1])[None, :, None], indice]

即在前两个维度添加2个辅助的遍历

猜你喜欢

转载自blog.csdn.net/qq_43666068/article/details/132054252