Tensorflow tf.argmax(vec, 1) 函数

版权声明:潘广宇博客, https://blog.csdn.net/panguangyuu/article/details/87554714

【函数功能】

返回vec中的最大值所在的索引号,如果vec是一个向量,则返回一个索引号,若是一个矩阵,则返回一个向量,该向量的每一维都对应矩阵行的最大值元素的索引号。

【函数实例】

import tensorflow as tf
import numpy as np

A = [[1,2,3,4,5]]
B = [[1,4,6], [5,8,1]]
C = [[1,3,5], [2,5,1], [3,7,10]]
 
with tf.Session() as sess:
    print(sess.run(tf.argmax(A, 1)))
    print(sess.run(tf.argmax(B, 1)))
    print(sess.run(tf.argmax(C, 1)))

输出:

[4]             # A中索引号为4的元素最大 
[2 1]           # B中第一行索引号为2的数6最大,第二行索引号为1对应的8最大
[2 1 2]         # C同理

猜你喜欢

转载自blog.csdn.net/panguangyuu/article/details/87554714