tf.where 实例：只传入条件参数时，tf.where 返回条件为 True 的元素坐标，每个匹配元素占一行，列数等于输入张量的维数。

import tensorflow as tf

# Example 1: a 2-D input of shape (5, 1).
# With only a condition argument, tf.where returns the coordinates of every
# True element: one row per match, one column per input dimension.
score_threshold = 3
x = tf.constant([[1], [2], [3], [4], [5]])
index = tf.where(tf.greater(x, score_threshold))

# tf.Session only exists in TensorFlow 1.x. TensorFlow 2.x removed it and
# executes eagerly by default, so there the value is read with .numpy().
if hasattr(tf, "Session"):
    with tf.Session() as sess:
        y = sess.run(index)
else:
    y = index.numpy()
print(y)

# Output:
# [[3 0]
#  [4 0]]

再如，输入为一维张量时，每个坐标只有一列：

import tensorflow as tf

# Example 2: a 1-D input of shape (5,).
# Each returned coordinate has a single column because the input has one
# dimension; rows are the positions where the condition is True.
score_threshold = 3
x = tf.constant([1, 2, 3, 4, 5])
index = tf.where(tf.greater(x, score_threshold))

# tf.Session only exists in TensorFlow 1.x. TensorFlow 2.x removed it and
# executes eagerly by default, so there the value is read with .numpy().
if hasattr(tf, "Session"):
    with tf.Session() as sess:
        y = sess.run(index)
else:
    y = index.numpy()
print(y)

# Output:
# [[3]
#  [4]]

猜你喜欢

转载自blog.csdn.net/weixin_43331915/article/details/83386545