tensorflow中的tf.argmax()

tf.argmax就是返回最大的那个数值的索引值。tf.argmax即np.argmax>

tf.argmax(array, 1)和tf.argmax(array, 0)有区别(分别是axis为1和0)。

例子:

test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
np.argmax(test, 0)   #输出:array([3, 3, 1]
np.argmax(test, 1)   #输出:array([2, 2, 0, 0]

解释:

tf.argmax(array, 0)是比较不同数组中相同位置的数字。比如下面是比较每一列的数字,返回每一列数字中最大值的索引。

当数组长度不一致时,axis=0,变成了每个数组的和的比较。

test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output   :    [3, 3, 1]    
tf.argmax(array, 0)是比较同一数组中的数字。比如下面是比较每一行数字,返回每一行数字中最大值的索引。
test[0] = array([1, 2, 3])  #2
test[1] = array([2, 3, 4])  #2
test[2] = array([5, 4, 3])  #0
test[3] = array([8, 7, 2])  #0


猜你喜欢

转载自blog.csdn.net/shashaqingmuzi/article/details/80256468
今日推荐