TensorFlow之Tensor对象

Tensor翻译成中文是指张量,tf.Tensor对象是数据对象的句柄。数据对象包括输入的常量和变量,以及计算节点的输出数据对象。所有Python语言中的常见类型的数据需要转为TensorFlow中的Tensor对象后,才能使用TensorFlow框架中的计算节点。

Tensor翻译成中文是指张量,零维度张量表示的是标量,一维张量表示的是向量,二维张量表示的是矩阵。TensorFlow中,训练卷积神经网络模型时,常见的Tensor维度为四维。常见四维Tensor的Shape为[batch, height, width, channels],即一个Batch的输入图片数量,网络层输出特征图的高度、宽度及通道数。

在TensorFlow中,Tensor对象可以存储任意维度的张量,图中参与计算的数据都是Tensor对象, 在前面学习的常量与变量就属于Tensor对象。Tensor对象往往是一个计算操作节点(Operation对象,简写op)的输出,输入其实也可以看成取数据op的输出。Tensor对象的概念比较容易理解,只需将它看成图中的数据即可。

代码:

import tensorflow as tf

data = [[1, 2], [3, 4]]
# 定义变量Tensor
A_tf = tf.Variable(data, name='A')
# 定义常量Tensor
B_tf = tf.constant(data, name='B')

# 根据Tensor的名称获取Tensor
A_tmp = tf.get_default_graph().get_tensor_by_name('A:0')
B_tmp = tf.get_default_graph().get_tensor_by_name('B:0')
# 将查找得到的Tensor对象做矩阵乘法
C_tf = tf.matmul(A_tf, B_tf)

# 将查找到的Tensor对象打印
print("Tensor named 'A:0' :", A_tmp)
print("Tensor named 'B:0' :", B_tmp)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    A_v, B_v, C_v = sess.run([A_tmp, B_tmp, C_tf])
    print("\n Tensor named 'A:0' value:\n", A_v)
    print("\n Tensor named 'B:0' value:\n", B_v)
    print("\n Matuml output:\n", C_v)

上述代码(TensorFlow版本从原来的2.0.0改成了1.15.0)中,第10-11行根据指定的名称从图中获取对应的Tensor对象;第13行将查找返回的Tensor对象做矩阵乘法;第21-23行将各个Tensor的数据输出打印。上面的代码执行后,输出结果如下:

Tensor named 'A:0' : Tensor("A:0", shape=(2, 2), dtype=int32_ref)
Tensor named 'B:0' : Tensor("B:0", shape=(2, 2), dtype=int32)

 Tensor named 'A:0' value:
 [[1 2]
 [3 4]]

 Tensor named 'B:0' value:
 [[1 2]
 [3 4]]

 Matuml output:
 [[ 7 10]
 [15 22]]
发布了105 篇原创文章 · 获赞 17 · 访问量 11万+

猜你喜欢

转载自blog.csdn.net/qq_38890412/article/details/104085994