tf.gather

函数:tf.gather
gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)

定义在:tensorflow/python/ops/array_ops.py

参见指南:张量变换>分割和连接

根据索引从参数轴上收集切片。 

import tensorflow as tf

temp = tf.range(0,10)*10 + tf.constant(1,shape=[10])
temp2 = tf.gather(temp,[1,5,9])

with tf.Session() as sess:

    print sess.run(temp)
    print sess.run(temp2)

猜你喜欢

转载自blog.csdn.net/xijuezhu8128/article/details/81220187
tf
今日推荐