tf.gather用法

tf.gather(params, indices, validate_indices=None, name=None, axis=0)
Gather slices from params axis axis according to indices.
从’params’的’axis’维根据’indices’的参数值获取切片。就是在axis维根据indices取某些值。
测试代码:

import tensorflow as tf
import numpy as np
print("\n先测试一维张量\n")
t=np.random.randint(1,10,5)
g1=tf.gather(t,[2,1,4])
sess=tf.Session()
print(t)
print(sess.run(g1))
print("\n再测试二维张量\n")
t=np.random.randint(1,10,[4,5])
g2=tf.gather(t,[1,2,2],axis=0)
g3=tf.gather(t,[1,2,2],axis=1)
print(t)
print(sess.run(g2))
print(sess.run(g3))

结果如下:

先测试一维张量
[7 4 7 1 3]
[7 4 3]
再测试二维张量
[[5 5 7 4 3]
 [8 7 6 5 2]
 [6 9 4 4 8]
 [7 3 3 2 2]]
[[8 7 6 5 2]
 [6 9 4 4 8]
 [6 9 4 4 8]]
[[5 7 7]
 [7 6 6]
 [9 4 4]
 [3 3 3]]

猜你喜欢

转载自blog.csdn.net/qq_31150463/article/details/84194006