torch.gather与tf.gather的区别

在 Tensorflow 和 PyTorch 中,

gather 函数用于按照指定的索引从输入张量中聚合元素并返回新的张量。

具体来说,该函数的作用和用法如下:

TensorFlow 中的 tf.gather() 函数:

tf.gather(params, indices, axis=None, batch_dims=0, name=None)

其中,params 表示输入的张量;indices 表示需要聚合的索引,可以是常量列表或张量;axis 是指定聚合维度的整数,如果不传递,则默认为 0;batch_dims 是指定批次维度数的整数,通常在将多个样本进行聚合时使用;name 是可选的操作名称。

例如,对于张量 params = [[1, 2], [3, 4], [5, 6], [7, 8]],要求按照索引 [0, 2, 3] 在第一维度上进行聚合,则可以使用以下代码:

import tensorflow as tf
params = tf.constant([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = [0, 2, 3]
output = tf.gather(params, indices, axis=0)
print(output)
# Output: [[1 2]
#          [5 6]
#          [7 8]]

PyTorch 中的 torch.gather() 函数:

torch.gather(input, dim, index, out=None)

其中,input 表示输入的张量;dim 表示需要聚合的维度;index 表示需要聚合的索引,可以是常量列表或张量;out 是可选的输出张量。

例如,对于张量 input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]]),要求按照索引 torch.tensor([[0], [2], [3]]) 在第一维度上进行聚合,则可以使用以下代码:

import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([[0], [2], [3]])
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1],
#                  [5],
#                  [7]])

——————————————————————————————

不同: 

tf.gather的索引为 [0, 2, 3]时,可以在第一维度上进行聚合,但

如果torch.gather的索引也设置为[0, 2, 3],则会报错:

RuntimeError: Index tensor must have the same number of dimensions as input tensor

而若torch.gather的索引设置为[[0], [2], [3]]

就会得到像上面的输出:

# Output: tensor([[1],
#           [5],
#           [7]])

为达到同样的效果,我们先用

torch.unsqueeze函数将tensor添加一个维度从[0, 2, 3]变为[[0], [2], [3]],
再用repeat函数对索引进行更改:
import torch
input = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
indices = torch.tensor([0, 2, 3])
indices = torch.unsqueeze(indices, dim=1)
indices = indices.repeat(1, input.size(1))
output = torch.gather(input, 0, indices)
print(output)
# Output: tensor([[1, 2],
#                  [5, 6],
#                  [7, 8]])

repeat函数表示对这个索引进行重复,第一个参数为1,表示行不变,第二个参数为

input.size(1),表示列重复input.size(1)遍,即  两遍,得到的indices为:

tensor([[0, 0],
        [2, 2],
        [3, 3]])

当遇到 多维张量 时,我们仍可这样操作

-----------------------------------------------------------------------

by the way,我们也可以对维度,也就是dim/axis进行操作,

我们可以对比一下

input = torch.tensor([[ 1,  2,  3],
                      [ 4,  5,  6],
                      [ 7,  8,  9],
                      [10, 11, 12]])

x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=x)


#tensor([[4, 5, 6],
#       [1, 2, 3]])


#tensor([[2],
#        [4]])
input = torch.tensor([[ 1,  2,  3],
                      [ 4,  5,  6],
                      [ 7,  8,  9],
                      [10, 11, 12]])

x = index = torch.tensor([[1], [0]])
index=index.repeat(1, input.size(1))
c = torch.gather(input,dim=0,index=index)
d = torch.gather(input,dim=1,index=index)


#tensor([[4, 5, 6],
#        [1, 2, 3]])

#tensor([[2, 2, 2],
#        [4, 4, 4]])

by the way way:

        tf.gather_nd可以替换tf.gather,但是在用gather_nd会引入更多额外参数,对4-d tensor,假设我们想用tf.gather_nd替换tf.gather,就要提取出对应轴的元素,此时的indices就要把想要的元素对应索引组成一个矩阵就可以了。

那不如来思考一个形状为[2,3,4,5] 的parmas,
如何把tf.gather(params,axis = 3,indices=[0,2])用tf.gather_nd来输出同样的结果。

import tensorflow as tf
import numpy as np
a = tf.Variable(tf.random_uniform(shape=(2, 3, 4, 5), name="v"))
nnzs = [0,2]
nnzs = np.asarray(nnzs,"int32")
#initi= np.asarray(initi,"int32")
initi =np.zeros((2,3,4,nnzs.size,4),dtype=np.int)
print(initi.shape)
for i in range(initi.shape[0]):
    for j in range(initi.shape[1]):
        for k in range (initi.shape[2]):
           for l in range(nnzs.size):
                initi[i][j][k][l] =[i,j,k,nnzs[l]]
indices = tf.Variable(initial_value = initi, name="indices")
c = tf.gather_nd(a,initi)
d = tf.gather(a,indices= nnzs,axis =3)
print(c.get_shape())
print(d.get_shape())
e = tf.equal(c,d)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([e]))


最后输出一个全为true的array。

猜你喜欢

转载自blog.csdn.net/djdjdhch/article/details/130598827