tensorflow入门:tensor类

tensor类
print("build a graph")
a = tf.constant([[1,2],[3,4]])
b = tf.constant([[1,1],[0,1]])
print("a:",a)
print("b:",b)
print("type of a:",type(a))
c = tf.matmul(a,b)
print("c:",c)
print("\n")
sess  = tf.Session()  #首字母大写

print("excuted in session")
result_a = sess.run(a)
result_a2 = a.eval(session=sess)
print("result_a:\n",result_a)
print("result_a2:\n",result_a2)

result_b = sess.run(b)
print("result_b:\n",result_b)

result_c = sess.run(c)
print("result_c:\n",result_c)
#输出:
build a graph
a: Tensor("Const:0", shape=(2, 2), dtype=int32)
b: Tensor("Const_1:0", shape=(2, 2), dtype=int32)
type of a: <class 'tensorflow.python.framework.ops.Tensor'>
c: Tensor("MatMul:0", shape=(2, 2), dtype=int32)
excuted in session
result_a:
 [[1 2]
 [3 4]]
result_a:
 [[1 2]
 [3 4]]
result_b:
 [[1 1]
 [0 1]]
result_c:
 [[1 3]
 [3 7]]
形象的介绍一下tensor类

整个程序分为3个过程:
1)建造计算图:用constant()函数弄了两个tensor分别是a和b,然后我们试图直接输出a和b,认为能够输出两个矩阵(至少我们以前的编程经验就是这样),但是输出是各自度对应的tensor类型。通过print(“type of a:”,type(a)) 这句话来输出a的类型,果然是tensor类型(tensor类)。
2) 把a和b这两个tensor传递给tf.matmul()函数:这个函数是用来计算矩阵乘法的函数。返回的依然是tensor用c来接受。到这里为止,印证了之前说了,tensor里面并不负责储存值,想要得到值,得去Session中run。我们可以把这部分看做是创建了一个图,但没有运行这个图。 构造一个Session的对象用来执行图,sess=tf.Session() 。
3) 在session里面执行:可以把一个tensor传递到session.run()里面去,得到其值。等价的也可以用result_a2=a.eval(session=sess) 来得到。那么返回的结果是什么呢?比如result_c是一个什么东西呢?是numpy.ndarray。要是你熟悉numpy的话。

属性:

属性 介绍
device 表示tensor将被产生的设备名称
dtype tensor的元素类型
graph 这个tensor被哪个图所有
name 这个tensor的名称
op 产生这个tensor作为输出的操作(Operation)
shape tensor的形状(返回 的是tf.TensorShape这个表示tensor形状的类)
value_index 表示这个tensor在其操作结果中的索引

函数:

tf.Tensor.sonsumers():

返回消耗这个tensor的操作列表

tf.Tensor.eval(feed_dict=None,session=None):

在一个Session里面评估tensor的值(相当于计算),首先执行之前所有必要的操作来产生个tensor需要的输入,然后通过这些输入产生这个tensor。在激发tensor.eval()这个函数之前,tensor已经投入带session里面,或者一个默认的sesson是最有效的,或者显示指定session参数:
feed_dict:一个字典,用来表示tensor被feed的值
session: 用来计算这个tensor的session,要是没有指定的话,就会使用on.
返回:
表示计算结果值的numpy ndarray

tf.Tensor.get_shape()

返回tensor的形状,类型是TensorShape。这个函数不用把图“投放”到session里面运行就能够得到形状,一般用来debug和得到一些早期的错误信息等等。
例子:

c = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
print(c.get_shape())
# 输出:(2, 3)

d = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
print(d.get_shape())
# 输出:(4, 2)

# c * d 维度不符合,报错
e = tf.matmul(c,d)
print(e.get_shape())

# c * d 两个矩阵先分别转置,再相乘,transpose_a=True表示转置
f = tf.matmul(c, d, transpose_a=True, transpose_b=True)   
print(f.get_shape())
# 输出:(3, 4)

tf.Tensor.set_shape(shape)

设置更新这个tensor的形状

x1 = tf.placeholder(tf.int32)
print(x1.get_shape())

x1.set_shape([2,2])
print(x1.get_shape())

sess = tf.Session()
print(sess.run(tf.shape(x1), feed_dict={x1:[[0,1],[2,3]]}))
# 如果去掉注释,就会报错,因为我们传入了与图片格式不符合的数据
# #print(sess.run(tf.shape(x1), feed_dict={x1:[0,1,2,3]}))
#输出:
4
(2,2)
[2,2]

reshape

x1 = tf.placeholder(tf.int32)
x2 = tf.reshape(x1, [2,2]) # use tf.reshape()
print(x1.get_shape())

sess = tf.Session()
print(sess.run(tf.shape(x2), feed_dict={x1:[0,1,2,3]}))
print(sess.run(tf.shape(x2), feed_dict={x1:[[0,1],[2,3]]}))
#输出:
#unknown
#[2,2]
#[2,2]

猜你喜欢

转载自blog.csdn.net/acbattle/article/details/80136352