torch.gather的三维实例

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_26114733/article/details/88088077
>>> a
tensor([[ 0.9918,  0.4911,  1.4912, -1.8491],
        [ 0.1257, -0.4406,  0.3371,  0.1205],
        [ 0.3064, -0.8198,  1.2851,  0.2486]])
>>> b
tensor([[0, 1],
        [1, 2],
        [2, 2]])
>>> a.unsqueeze(1).expand(3,2,4).gather(dim=0,index=b.unsqueeze(2).expand(3,2,4))
tensor([[[ 0.9918,  0.4911,  1.4912, -1.8491],
         [ 0.1257, -0.4406,  0.3371,  0.1205]],

        [[ 0.1257, -0.4406,  0.3371,  0.1205],
         [ 0.3064, -0.8198,  1.2851,  0.2486]],

        [[ 0.3064, -0.8198,  1.2851,  0.2486],
         [ 0.3064, -0.8198,  1.2851,  0.2486]]])

猜你喜欢

转载自blog.csdn.net/sinat_26114733/article/details/88088077
今日推荐