tensorflow函数——tf.argmax(x, axis)

tf.argmax(x, axis)
这个函数是衍生于numpy库中的numpy.argmax()。tf.argmax(x, axis) 函数返回的是张量x中元素最大值对应的index号(而不是元素最大值)

  • 参数x: 是一个函数的输入,是一个张量。例如a=[[1,2,3],[4,5,6],[7,8,9]] 这样的3X3的矩阵
  • 参数axis: 英文为“坐标轴”意思,它有2个值,当axis=0,表示要从每一列中寻找最大元素值对应的index号;当axis=1,表示从每一行中寻找最大元素值对应的index号

在python中可以试着验证一下:

import tensorflow as tf
a=[[1,0,3],[8,-3,6],[5,1,7]]
with tf.Session() as sess:
     print(sess.run(tf.argmax(a,1)))#按照行去查找每一行中最大值对应的index
     print(sess.run(tf.argmax(a,0)))#按照列去查找每一列中最大值对应的index

运行结果如下:
这里写图片描述

猜你喜欢

转载自blog.csdn.net/weixin_40769885/article/details/82226168
今日推荐