TensorFlow argmax 函数详解

def argmax(self, axis=None, fill_value=None, out=None):
    """
    Return the indices of the maximum values along a given axis.

    Masked values are treated as if they had the value `fill_value`.

    Parameters
    ----------
    axis : {None, integer}
        If None, the index is into the flattened array, otherwise along
        the specified axis.
    fill_value : scalar, optional
        Value used to fill in the masked values.  If None, the output of
        maximum_fill_value(self._data) is used instead.
    out : {None, array}, optional
        Array into which the result can be placed. Its type is preserved
        and it must be of the right shape to hold the output.

    Returns
    -------
    index_array : {integer_array}
        Indices of the maximum values (after masked entries are filled).

    Examples
    --------
    >>> a = np.arange(6).reshape(2,3)
    >>> a.argmax()
    5
    >>> a.argmax(0)
    array([1, 1, 1])
    >>> a.argmax(1)
    array([2, 2])

    """
    if fill_value is None:
        # Default fill value; presumably the smallest value representable by
        # the data's dtype, so masked entries never win the argmax
        # (see numpy.ma.maximum_fill_value).
        fill_value = maximum_fill_value(self._data)
    # Replace masked entries with fill_value, then view the result as a plain
    # ndarray so that `d.argmax` dispatches to ndarray.argmax.
    d = self.filled(fill_value).view(ndarray)
    return d.argmax(axis, out=out)

上面的函数其实是 NumPy 掩码数组（numpy.ma.MaskedArray）中 argmax 的实现；TensorFlow 的 tf.argmax 用法与之类似，看下面的例子就更明白了：

tf.argmax | tf.argmin

tf.argmax(input=tensor, axis=axis) 找到给定张量 tensor 在指定轴 axis 上最大值的位置（tf.argmin 则返回最小值的位置）。旧版本中该参数名为 dimension，现已弃用。

# Assumes `import tensorflow as tf` with the TF 1.x graph API (not shown here);
# under TF 2.x use tf.compat.v1 and disable eager execution.
a = tf.get_variable(name='a',
                    shape=[3, 4],
                    dtype=tf.float32,
                    initializer=tf.random_uniform_initializer(minval=-1, maxval=1))
# `axis` replaces the deprecated `dimension` keyword of tf.argmax.
b = tf.argmax(input=a, axis=0)  # row index of the max in each column -> shape [4]
c = tf.argmax(input=a, axis=1)  # column index of the max in each row -> shape [3]
# The `with` block closes the session when done; global_variables_initializer
# replaces the deprecated initialize_all_variables.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # `a` is randomly initialized, so the printed values differ on every run.
    print(sess.run(a))
    # [[ 0.04261756 -0.34297419 -0.87816691 -0.15430689]
    #  [ 0.18663144  0.86972666 -0.06103253  0.38307118]
    #  [ 0.84588599 -0.45432305 -0.39736366  0.38526249]]
    print(sess.run(b))
    # [2 1 1 2]
    print(sess.run(c))
    # [0 1 0]
部分内容来源于网络

猜你喜欢

转载自blog.csdn.net/liyaoqing/article/details/54020202