PyTorch中Tensor的高阶操作

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/weixin_44613063/article/details/89741267

where


gather

沿给定轴 dim,将输入索引张量 index 指定位置的值进行聚合

举个例子:

>>> t = torch.Tensor([[1,2],[3,4]])
>>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))
tensor([[1., 1.],
        [4., 3.]])

猜你喜欢

转载自blog.csdn.net/weixin_44613063/article/details/89741267