def test_segment(k=3, threshold=0.8):
    """Group scores by class and keep the highest score of each class.

    Builds a small TF1 graph over hard-coded demo data:
      1. Drops scores below ``threshold``.
      2. Groups the remaining scores by class id and takes the per-class maximum.
      3. Maps the top ``k`` classes to their tag strings.
    It then prints the filtered scores, the filtered labels, the final scores
    and the final tags.

    Args:
        k: maximum number of classes to report (Python int); clamped to the
            number of tags.
        threshold: minimum score a prediction needs in order to be kept.

    Returns:
        None. Results are printed.
    """
    tags = ['label1', 'label2', 'label3', 'label4']
    num_tags = len(tags)
    label_tags = tf.Variable(tf.constant(tags), trainable=False)
    # Class id of each prediction, stored as strings and parsed to int32 below.
    label_class = tf.Variable(tf.constant(['3', '2', '3', '2', '3', '0']), trainable=False)
    labels = tf.string_to_number(label_class, tf.int32)
    scores = tf.Variable(tf.constant((0.8, 0.2, 0.1, 0.9, 0.87, 0.98)), trainable=False)

    # Filter out low-scoring results first. TODO: benchmark the runtime with
    # and without this filtering step.
    # tf.where returns coordinates of shape [N, 1]; flatten them to [N] so
    # the gathered scores/labels stay 1-D.
    keep = tf.reshape(tf.where(tf.greater_equal(scores, threshold)), [-1])
    kept_scores = tf.gather(scores, keep)
    kept_labels = tf.gather(labels, keep)

    # One segment per tag, because class ids index into label_tags. Using one
    # segment per prediction would let top_k return a segment id >= num_tags,
    # and the gather into label_tags below would go out of range.
    segment_scores = tf.unsorted_segment_max(kept_scores, kept_labels, num_tags)
    final_scores, final_idx = tf.nn.top_k(segment_scores, k=min(k, num_tags))

    # unsorted_segment_max fills classes that have no surviving score with the
    # dtype's lowest value. Every real maximum is >= threshold, so drop the rest.
    valid = tf.greater_equal(final_scores, threshold)
    final_scores = tf.boolean_mask(final_scores, valid)
    final_labels = tf.boolean_mask(tf.gather(label_tags, final_idx), valid)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Evaluate everything in one run instead of re-running the graph per value.
        for value in sess.run([kept_scores, kept_labels, final_scores, final_labels]):
            print(value)
if __name__ == '__main__':
    # Run the demo only when this file is executed as a script, not on import.
    test_segment()