tensorflow 获得tensor的维度信息,tf.shape()与 a.get_shape()的比较

转自:不知道哪里yychenxie21

相同点:都可以得到tensor a的尺寸
不同点:tf.shape()中a 数据的类型可以是tensor, list, array
a.get_shape()中a的数据类型只能是tensor,且返回的是一个元组(tuple)
注意到tf.shape(a)返回的是一个OP需要再sess.run(),而a.get_shape()得到一个实际的元组。


  
  
  1. import tensorflow as tf
  2. import numpy as np
  3. x=tf.constant([[ 1, 2, 3],[ 4, 5, 6]]
  4. y=[[ 1, 2, 3],[ 4, 5, 6]]
  5. z=np.arange( 24).reshape([ 2, 3, 4]))
  6. sess=tf.Session()
  7. # tf.shape()
  8. x_shape=tf.shape(x) # x_shape 是一个tensor
  9. y_shape=tf.shape(y) # <tf.Tensor 'Shape_2:0' shape=(2,) dtype=int32>
  10. z_shape=tf.shape(z) # <tf.Tensor 'Shape_5:0' shape=(3,) dtype=int32>
  11. print sess.run(x_shape) # 结果:[2 3]
  12. print sess.run(y_shape) # 结果:[2 3]
  13. print sess.run(z_shape) # 结果:[2 3 4]
  14. #a.get_shape()
  15. x_shape=x.get_shape() # 返回的是TensorShape([Dimension(2), Dimension(3)]),不能使用 sess.run() 因为返回的不是tensor 或string,而是元组
  16. x_shape=x.get_shape().as_list() # 可以使用 as_list()得到具体的尺寸,x_shape=[2 3]
  17. y_shape=y.get_shape() # AttributeError: 'list' object has no attribute 'get_shape'
  18. z_shape=z.get_shape() # AttributeError: 'numpy.ndarray' object has no attribute 'get_shape'
import tensorflow as tf
input_tensor=tf.get_variable(name="input",shape=[1,5,5,1],initializer=tf.truncated_normal_initializer(stddev=0.1))
input_tensor_size=tf.shape(input_tensor)
print("直接运行tf.shape() :",input_tensor_size)    #输出是tf.Tensor 'Shape:0' shape=(4,) dtype=int32
print("运行tensor.get_shape()  :",input_tensor.get_shape()) #输出是TensorShape([Dimension(1), Dimension(5), Dimension(5), Dimension(1)])

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print("在sess.run()中运行tf.shape():",sess.run(input_tensor_size)) #输出是array([1, 5, 5, 1], dtype=int32)
    print("===========================")

猜你喜欢

转载自blog.csdn.net/taolusi/article/details/81228256