tensorflow对数组排序

有时我们会遇到tensor域下的数组排序,比如按照一定规则对输入排序。

import tensorflow as tf
import numpy as np

a = tf.placeholder(tf.int32, shape=(3,2))


# bb = tf.constant(a) # the array
reordered = tf.gather(a, tf.nn.top_k(a[:, 0], k=3).indices)   # 按照输入的第一个维度排序,选取top3的值
value_ = tf.nn.top_k(a[:, 0], k=3).values
indices_ = tf.nn.top_k(a[:, 0], k=3).indices
'''
tf.nn.top_k(
    input,
    k=1,
    sorted=True,
    name=None
)
top_k返回值:
top_k(...).values: The k largest elements along each last dimensional slice.(返回对应的值)
top_k(...).indices: The indices of values within the last dimension of input(返回索引)

-----------------
tf.gather(
    params,
    indices,
    validate_indices=None,
    name=None,
    axis=0
)
根据对应索引indices把对应元素取出来
'''
 

feed_dict = {a:np.array([[1, 2], [3, 4], [2, 2]])}
sess = tf.Session()
_in, out, v, i = sess.run([a, reordered, value_, indices_], feed_dict=feed_dict)
print('in:\n',_in, '\nout:\n', out, '\nvalue:\n', v, '\nindices:\n', i)
'''
>>>in:
 [[1 2]
 [3 4]
 [2 2]]
out:
 [[3 4]
 [2 2]
 [1 2]]
value:
 [3 2 1]
indices:
 [1 2 0]

'''
 

猜你喜欢

转载自blog.csdn.net/md2017/article/details/81182423