tensorflow实现分组取top_k

def test_segment():
    """
    分组取出各个类别的最大概率
    :return:
    """
    k = 3
    label_tags = tf.Variable(tf.constant(['label1', 'label2', 'label3', 'label4']), trainable=False)
    label_class = tf.Variable(tf.constant(['3', '2', '3', '2', '3', '0']), trainable=False)
    _labels = tf.string_to_number(label_class, tf.int32)
    _scores = tf.Variable(tf.constant((0.8, 0.2, 0.1, 0.9, 0.87, 0.98)), trainable=False)

    # 先过滤不符合要求的结果,此处需要实际评测过滤和不过滤的耗时
    _indices = tf.where(tf.greater_equal(_scores, 0.8))
    _scores = tf.gather(_scores, _indices)
    _labels = tf.gather(_labels, _indices)

    segment_scores = tf.unsorted_segment_max(_scores, _labels, tf.size(label_class))
    final_scores, final_idx = tf.nn.top_k(segment_scores, k=tf.minimum(k, tf.size(label_class)))
    final_labels = tf.gather(label_tags, final_idx)
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        print(_scores.eval())
        print(_labels.eval())
        print(final_scores.eval())
        print(final_labels.eval())
if __name__ == '__main__':
    test_segment()
发布了127 篇原创文章 · 获赞 10 · 访问量 24万+

猜你喜欢

转载自blog.csdn.net/u012599545/article/details/89438878
今日推荐