pytorch文档学习

(1)

torch.gather(input, dim, index, out=None) → Tensor

input是输入tensor;dim维度(二维矩阵,0表示在列的方向,行的索引;1表示在行的方向,列的索引);index索引
out[i][j][k] = input[index[i][j][k]][j][k] # dim=0
out[i][j][k] = input[i][index[i][j][k]][k] # dim=1
out[i][j][k] = input[i][j][index[i][j][k]] # dim=2

二维矩阵
[[1,2],[3,4]] dim=1 [[0,0],[1,0]] -------------[[1,1],[43]]
在行的方向上[1,2]对应索引[0,0]就为[1,1]

(2)

torch.index_select(input, dim, index, out=None) → Tensor
>>> x = torch.randn(3, 4)
>>> x

 1.2045  2.4084  0.4001  1.1372
 0.5596  1.5677  0.6219 -0.7954
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 3x4]

>>> indices = torch.LongTensor([0, 2])
>>> torch.index_select(x, 0, indices)

 1.2045  2.4084  0.4001  1.1372
 1.3635 -1.2313 -0.5414 -1.8478
[torch.FloatTensor of size 2x4]

>>> torch.index_select(x, 1, indices)

 1.2045  0.4001
 0.5596  0.6219
 1.3635 -0.5414

dim=0时,第一列[1.2045 0.5596 1.3635] 取索引[0,2] 生成[1.2045 1.3635]

猜你喜欢

转载自blog.csdn.net/qq_32560769/article/details/85764466
今日推荐