tf.gather的使用说明

tf.gather(params, indices, validate_indices=None, name=None, axis=0)
Gather slices from params axis axis according to indices.

从’params’的’axis’维根据’indices’的参数值获取切片。就是在axis维根据indices取某些值。
在这里插入图片描述

import tensorflow as tf

temp = tf.range(0,10)*10 + tf.constant(1,shape=[10])
temp2 = tf.gather(temp,[1,5,9])

with tf.Session() as sess:

    print sess.run(temp)
    print sess.run(temp2)


输出
[ 1 11 21 31 41 51 61 71 81 91]
[11 51 91]

猜你喜欢

转载自blog.csdn.net/Harpoon_fly/article/details/85001580
今日推荐