对tensorflow中张量tensor的理解与tf.argmax()函数的用法

版权声明:站在巨人的肩膀上学习。 https://blog.csdn.net/zgcr654321/article/details/83652272

对tensorflow中张量tensor的理解:

一维张量:

如a=[1., 2., 3., 0., 9., ],其shape为(5,),故当我们选择维度0时(张量的维度总是从第0个维度开始),实际上是在a的最外层括号上进行操作。

我们画图来表示:

二维张量:

如b=[[1, 2, 3], [3, 2, 1], [4, 5, 6], [6, 5, 4]],其shape为(4, 3),当我们选择维度0时,是在b的最外层括号上进行操作;当我们选择维度1时,是在b的第二层括号内进行操作。

如果我们把b写成矩阵的形式的话,当我们选择维度0时,就是在矩阵的列上进行操作;当我们选择维度1时,就是在矩阵的行上进行操作。

我们画一幅图来表示:

三维张量:

如c = tf.constant([[[1, 3, 1, 4], [5, 3, 2, 1], [1, 2, 2, 4]], [[4, 2, 3, 1], [5, 3, 1, 9], [3, 7, 2, 3]]]),其shape为(2,3,4)。类似地,当我们选择维度0时,也是在最外层括号上进行操作;当我们选择维度1时,是在第二层括号上进行操作;当我们选择维度2时,是在最里层的括号上进行操作。

为了更加直观,我们画一幅图来表示:

tf.argmax()函数的用法:

函数形式:

tf.argmax(input=tensor,dimention=axis)

该函数返回指定的张量tensor中指定维度上的最大值/最小值的下标(即位置)。dimension=0则查找tensor上维度0上的最大值(如果是一维数组,维度0就是行,如果是二维数组,维度0就是列)dimension=1则查找tensor上维度1上的最大值(如果是二维数组,维度1就是行) 。dimension = 2、3、4...,即为多维张量时,按同理推断。

如果tensor是一个向量,那就返回一个值,如果是一个矩阵,那就返回一个向量,这个向量的每一个维度都是相对应矩阵指定的维度上的最大值元素的索引号。

以上面的a,b,c张量举例:

import tensorflow as tf

with tf.Session() as sess:
	print("创建一个一维张量a:")
	a = tf.constant([1., 2., 3., 0., 9., ])
	print(a, a.shape)
	print("创建一个二维张量b,b是一个4X3矩阵,矩阵的每个元素是一个值,b有2个维度:")
	b = tf.constant([[1, 2, 3], [3, 2, 1], [4, 5, 6], [6, 5, 4]])
	print(b, b.shape)
	print("创建一个三维张量c,c是一个2X3矩阵,矩阵的每个元素时一个有4个元素的数组,c有3个维度:")
	c = tf.constant([[[1, 3, 1, 4], [5, 3, 2, 1], [1, 2, 2, 4]],
					 [[4, 2, 3, 1], [5, 3, 1, 9], [3, 7, 2, 3]]])
	print(c, c.shape)
	print("查找一维张量a的最大值的下标:")
	print(sess.run(tf.argmax(a, 0)))
	print("查找二维张量b的每列最大值的下标:")
	print(sess.run(tf.argmax(b, 0)))
	print("查找二维张量b的每行最大值的下标:")
	print(sess.run(tf.argmax(b, 1)))
	print("查找三维张量c的维度0上最大值的下标:")
	print(sess.run(tf.argmax(c, 0)))
	print("查找三维张量c的维度1上最大值的下标:")
	print(sess.run(tf.argmax(c, 1)))
	print("查找三维张量c的维度2上最大值的下标:")
	print(sess.run(tf.argmax(c, 2)))

运行结果如下:

创建一个一维张量a:
Tensor("Const:0", shape=(5,), dtype=float32) (5,)
创建一个二维张量b,b是一个4X3矩阵,矩阵的每个元素是一个值,b有2个维度:
Tensor("Const_1:0", shape=(4, 3), dtype=int32) (4, 3)
创建一个三维张量c,c是一个2X3矩阵,矩阵的每个元素时一个有4个元素的数组,c有3个维度:
Tensor("Const_2:0", shape=(2, 3, 4), dtype=int32) (2, 3, 4)
查找一维张量a的最大值的下标:
4
查找二维张量b的每列最大值的下标:
[3 2 2]
查找二维张量b的每行最大值的下标:
[2 0 2 0]
查找三维张量c的维度0上最大值的下标:
[[1 0 1 0]
 [0 0 0 1]
 [1 1 0 0]]
查找三维张量c的维度1上最大值的下标:
[[1 0 1 0]
 [1 2 0 1]]
查找三维张量c的维度2上最大值的下标:
[[3 0 3]
 [0 3 1]]

Process finished with exit code 0

对于一维数组,dimension=0时tf.argmax函数就是对唯一的一个维度行上取最大值下标;

对于二维数组,dimension=0时tf.argmax函数是对行上取最大值下标,dimension=1时tf.argmax函数是对列上取最大值下标;

对于三维数组c(shape=(2,3,4)),写出来是这样的形式:

如上图,各个维度的操作方向如图所示,因此:

dimension=0时tf.argmax函数是对第一种c的写法的纵向的列上取最大值下标,我们可以发现一共取了12个值;

dimension=1时tf.argmax函数是对第二种c的写法的纵向的列上取最大值下标(注意是对第二层括号内),因此一共取了8个值;

dimension=2时tf.argmax函数是对第一种c的写法的行方向上的最内部数组里取最大值的下标,一共取了6个值。

在实际应用中,tf.argmax()往往和tf.equal()在tensorflow的模型中一起使用,用来计算模型的准确度。

如:

correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

猜你喜欢

转载自blog.csdn.net/zgcr654321/article/details/83652272