tf.gather(params, indices, validate_indices=None, name=None, axis=0)
Gather slices from params axis axis according to indices.
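In other words, it takes slices from dimension 'axis' of 'params' according to the values in 'indices': along dimension axis, it picks the entries at the positions listed in indices.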
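To make the role of axis concrete, here is a minimal pure-Python sketch of the gather semantics on nested lists. The function name gather and the sample data are illustrative assumptions, not TensorFlow's implementation. The sketch only handles a flat list of indices and does no bounds checking, whereas the real op also accepts multi-dimensional indices.

```python
def gather(params, indices, axis=0):
    """Hypothetical pure-Python sketch of gather on nested lists.

    Picks the entries at positions `indices` along dimension `axis`;
    all other dimensions are kept unchanged. `indices` is a flat list here.
    """
    if axis == 0:
        # Outermost dimension: select whole sub-lists (slices) by position.
        return [params[i] for i in indices]
    # Deeper axis: apply the same selection inside each outer slice.
    return [gather(sub, indices, axis - 1) for sub in params]


temp = [i * 10 + 1 for i in range(10)]     # [1, 11, 21, ..., 91]
print(gather(temp, [1, 5, 9]))             # [11, 51, 91]

matrix = [[1, 2, 3],
          [4, 5, 6]]
print(gather(matrix, [1, 0]))              # rows swapped: [[4, 5, 6], [1, 2, 3]]
print(gather(matrix, [2, 0], axis=1))      # columns 2 and 0: [[3, 1], [6, 4]]
```

With axis=0 whole rows are selected. With axis=1 the same index list is applied inside every row, so columns are selected instead.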
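# Written for TensorFlow 1.x (tf.Session is not in the 2.x top-level API)
import tensorflow as tf

# temp = [1, 11, 21, ..., 91]: the range 0..9 scaled by 10, plus 1
temp = tf.range(0, 10) * 10 + tf.constant(1, shape=[10])
# Pick the elements at positions 1, 5 and 9 along axis 0
temp2 = tf.gather(temp, [1, 5, 9])
with tf.Session() as sess:
    print(sess.run(temp))
    print(sess.run(temp2))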
Output:
[ 1 11 21 31 41 51 61 71 81 91]
[11 51 91]
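To check output shapes without a TensorFlow session, NumPy's np.take with an explicit axis can serve as a stand-in. This assumes non-negative, in-range indices, because the two functions may treat negative or out-of-range indices differently. The sketch uses 2-D indices to illustrate the output-shape rule params.shape[:axis] + indices.shape + params.shape[axis + 1:] given in the tf.gather documentation. The array values are made up for illustration.

```python
import numpy as np

# Illustrative params: a 3 x 4 array holding 0..11
params = np.arange(12).reshape(3, 4)

# 2-D indices that select columns along axis 1
indices = np.array([[0, 3],
                    [1, 1]])

axis = 1
# Pass axis explicitly: with the default axis=None, np.take flattens params first
out = np.take(params, indices, axis=axis)

# Shape rule: params.shape[:axis] + indices.shape + params.shape[axis + 1:]
expected_shape = params.shape[:axis] + indices.shape + params.shape[axis + 1:]
print(out.shape, expected_shape)   # (3, 2, 2) (3, 2, 2)
print(out[2].tolist())             # row [8, 9, 10, 11] -> [[8, 11], [9, 9]]
```

The indices array keeps its own shape in the result: each row of params yields a 2 x 2 block.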