tf.argmax的axis理解

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/guotong1988/article/details/82951319
import tensorflow as tf
tf.enable_eager_execution()

value = [[0, 1, 2, 3],
         [4, 5, 6, 7]]
init = tf.constant_initializer(value)
x = tf.get_variable('x', shape=[2,4], initializer=init)

print(tf.argmax(x,axis=0)) # 列
print(tf.argmax(x,axis=1)) # 行

打印结果:
tf.Tensor([1 1 1 1], shape=(4,), dtype=int64)
tf.Tensor([3 3], shape=(2,), dtype=int64)

猜你喜欢

转载自blog.csdn.net/guotong1988/article/details/82951319
今日推荐