tf.gather,tf.range()的详解

在讲解这个之前,我们首先讲一下tf.range(),因为这两个一般都是在一起用的

tf.range()

其和python中的range()的用法基本一样,只不过这里返回的是一个1-D的tensor
tf.range(limit, delta=1, dtype=None, name=‘range’)
tf.range(start, limit, delta=1, dtype=None, name=‘range’)

'''
Args:
	start: A 0-D Tensor (scalar). Acts as first entry in the range if limit is not None; otherwise, acts as range limit and first entry defaults to 0.
	limit: A 0-D Tensor (scalar). Upper limit of sequence, exclusive. If None, defaults to the value of start while the first entry of the range defaults to 0.
	delta: A 0-D Tensor (scalar). Number that increments start. Defaults to 1.
	dtype: The type of the elements of the resulting tensor.
	name: A name for the operation. Defaults to "range".
Returns:
	An 1-D Tensor of type dtype.
'''

tf.gather

该接口的作用:就是抽取出params的第axis维度上在indices里面所有的index(看后面的例子,就会懂)

tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

'''
Args:
	params: A Tensor. The tensor from which to gather values. Must be at least rank axis + 1.
	indices: A Tensor. Must be one of the following types: int32, int64. Index tensor. Must be in range [0, params.shape[axis]).
	axis: A Tensor. Must be one of the following types: int32, int64. The axis in params to gather indices from. Defaults to the first dimension. Supports negative indexes.
	name: A name for the operation (optional).
Returns:
	A Tensor. Has the same type as params.
'''
说明
参数
  • params: A Tensor.
  • indices: A Tensor. types必须是: int32, int64. 里面的每一个元素大小必须在 [0, params.shape[axis])范围内.
  • axis: 维度。沿着params的哪一个维度进行抽取indices
返回

返回的是一个tensor

帮助理解图

在这里插入图片描述

例子1

代码
import tensorflow as tf

Params = tf.range(0,10)*10
a = tf.gather(Params,[0,5,9])

with tf.Session() as sess:
    print("Params:  \n",sess.run(Params))
    print("抽取的结果: \n",sess.run(a))


输出

Params:  
        [ 0 10 20 30 40 50 60 70 80 90]
抽取的结果: 
      [ 0 50 90]

例子2

代码
import tensorflow as tf
Params=tf.Variable(tf.random_normal([2,3,4]))
indicxs_0=[0,1]
indicxs_1=[0,2]
indicxs_2=[2,3]

gather_0=tf.gather(params=Params,indices=indicxs_0,axis=0)
gather_1=tf.gather(params=Params,indices=indicxs_1,axis=1)
gather_2=tf.gather(params=Params,indices=indicxs_2,axis=2)


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print("Params :\n      ",sess.run(Params))
    print("沿着第O维度抽取第0,1个:    \n",sess.run(gather_0))
    print("沿着第1维度抽取第0,3个:    \n",sess.run(gather_1))
    print("沿着第2维度抽取第3,4个:    \n",sess.run(gather_2))

输出

Params :
       [
       [[ 0.78150964  2.09648061  2.37558031  1.20743346]
		[-1.12413085 -0.66349769  1.15486336 -1.17151475]
		 [ 0.0476133  -0.09292984 -0.29620713  0.70557141]]

		 [[-1.34968698 -0.2931003  -1.94950449 -0.27036974]
		  [ 0.27591622 -0.19094539 -0.56113148  0.55863774]
		  [-0.48273012 -0.7819376   0.3261987  -0.97833097]]
		  ]
沿着第O维度抽取第0,1个:    
 [[[ 0.78150964  2.09648061  2.37558031  1.20743346]
  [-1.12413085 -0.66349769  1.15486336 -1.17151475]
  [ 0.0476133  -0.09292984 -0.29620713  0.70557141]]

 [[-1.34968698 -0.2931003  -1.94950449 -0.27036974]
  [ 0.27591622 -0.19094539 -0.56113148  0.55863774]
  [-0.48273012 -0.7819376   0.3261987  -0.97833097]]]
沿着第1维度抽取第0,3个:    
 [[[ 0.78150964  2.09648061  2.37558031  1.20743346]
  [ 0.0476133  -0.09292984 -0.29620713  0.70557141]]

 [[-1.34968698 -0.2931003  -1.94950449 -0.27036974]
  [-0.48273012 -0.7819376   0.3261987  -0.97833097]]]
沿着第2维度抽取第3,4个:    
 [[[ 2.37558031  1.20743346]
  [ 1.15486336 -1.17151475]
  [-0.29620713  0.70557141]]

 [[-1.94950449 -0.27036974]
  [-0.56113148  0.55863774]
  [ 0.3261987  -0.97833097]]]

猜你喜欢

转载自blog.csdn.net/qq_32806793/article/details/85324531
tf
今日推荐