Deep learning and TensorFlow study notes (28): static and dynamic shapes

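In TensorFlow, a tensor's static shape is the shape known while the graph is being built, before any data flows through it. Individual dimensions of the static shape may still be unknown (None); the dynamic shape is the actual shape of the data at run time.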

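# Written against the TF1 graph API (tf.placeholder). On TensorFlow 2.x, replace the import with:
# import tensorflow.compat.v1 as tf; tf.disable_v2_behavior()
import tensorflow as tf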

a = tf.placeholder(tf.float32, [None, 128])
static_shape = a.shape.as_list()
dynamic_shape = tf.shape(a)

print(static_shape)
print(dynamic_shape)

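None means the size of that dimension is unknown when the graph is built: any size is accepted, and the actual value is only determined when Session.run() feeds data. tf.shape() returns the dynamic shape. It is itself a tensor, so its value is available only after it is evaluated, for example with Session.run().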

Output:

[None, 128]
Tensor("Shape:0", shape=(2,), dtype=int32)
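To make the role of None concrete, here is a pure-Python sketch (no TensorFlow required) of the compatibility rule that applies when a value is fed for a placeholder: a None dimension matches any size, a known dimension must match exactly, and the rank must agree. The names is_compatible and dynamic_shape are hypothetical; they model the rule rather than call TensorFlow, and the error message is made up for illustration.

```python
from typing import List, Optional, Sequence


def is_compatible(static_shape: Sequence[Optional[int]],
                  runtime_shape: Sequence[int]) -> bool:
    """Model of the feed check: None matches any size, known dims must be equal."""
    if len(static_shape) != len(runtime_shape):
        return False  # rank mismatch is never compatible
    return all(s is None or s == r for s, r in zip(static_shape, runtime_shape))


def dynamic_shape(static_shape: Sequence[Optional[int]],
                  runtime_shape: Sequence[int]) -> List[int]:
    """Model of what tf.shape() evaluates to: the concrete shape of the fed value."""
    if not is_compatible(static_shape, runtime_shape):
        raise ValueError(f"cannot feed shape {tuple(runtime_shape)} "
                         f"for a placeholder of shape {list(static_shape)}")
    return list(runtime_shape)


static = [None, 128]  # like tf.placeholder(tf.float32, [None, 128])
print(dynamic_shape(static, (32, 128)))  # [32, 128]
print(dynamic_shape(static, (7, 128)))   # [7, 128]
print(is_compatible(static, (7, 64)))    # False
```

The same static shape [None, 128] therefore admits batches of any size, and only the dynamic shape records which size was actually fed.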

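set_shape() refines a tensor's static shape in place. It does not change the data, and the new shape must be compatible with the existing one (here the unknown None is narrowed to 32):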

a.set_shape([32, 128])
static_shape = a.shape.as_list()
dynamic_shape = tf.shape(a)
print(static_shape)
print(dynamic_shape)

Output:

[32, 128]
Tensor("Shape_1:0", shape=(2,), dtype=int32)
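Because set_shape() only merges the new shape into the existing static shape, its behaviour can be modelled without TensorFlow. The pure-Python sketch below assumes the usual merge rule: two dimensions are compatible if they are equal or either is None, and the known value wins. merge_shapes is a hypothetical name, and TensorFlow raises its own ValueError message rather than the ones shown here.

```python
from typing import List, Optional, Sequence

Dim = Optional[int]


def merge_shapes(current: Sequence[Dim], new: Sequence[Dim]) -> List[Dim]:
    """Model of refining a static shape: keep known dims, reject conflicts."""
    if len(current) != len(new):
        raise ValueError(f"rank mismatch: {list(current)} vs {list(new)}")
    merged: List[Dim] = []
    for c, n in zip(current, new):
        if c is not None and n is not None and c != n:
            raise ValueError(f"incompatible dimensions {c} and {n}")
        merged.append(c if c is not None else n)  # the known value wins
    return merged


print(merge_shapes([None, 128], [32, 128]))  # [32, 128]: None is narrowed to 32
print(merge_shapes([32, 128], [None, 128]))  # [32, 128]: None cannot undo a known dim
```

Under this rule, merge_shapes([32, 128], [64, 128]) raises, which mirrors why set_shape() can narrow an unknown dimension but not change a known one; changing the layout of the data itself is what tf.reshape() is for.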

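tf.reshape() can also set the shape. Unlike set_shape(), it creates a new op whose static shape is taken from the target shape, and at run time the fed data must then contain exactly 32 × 128 elements: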

a = tf.placeholder(tf.float32, [None, 128])
a = tf.reshape(a, [32, 128])
static_shape = a.shape.as_list()
dynamic_shape = tf.shape(a)
print(static_shape)
print(dynamic_shape)

Output:

[32, 128]
Tensor("Shape:0", shape=(2,), dtype=int32)
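Since tf.reshape() runs on the data, its run-time check is about element counts: with a target of [32, 128], only an input with 4096 elements (a batch of exactly 32 rows) will succeed. tf.reshape() also accepts a single -1 entry whose size is inferred from the rest. The pure-Python sketch below models that check for a concrete input shape; resolve_reshape is a hypothetical name, not a TensorFlow function, and zero-sized dimensions are not handled.

```python
import math
from typing import List, Sequence


def resolve_reshape(input_shape: Sequence[int], target: Sequence[int]) -> List[int]:
    """Model of the run-time reshape check for a concrete input shape."""
    total = math.prod(input_shape)  # number of elements in the input
    if list(target).count(-1) > 1:
        raise ValueError("at most one dimension may be -1")
    known = math.prod(d for d in target if d != -1)
    if -1 in target:
        # The -1 entry absorbs whatever is left, so it must divide evenly.
        if known == 0 or total % known != 0:
            raise ValueError(f"cannot infer -1 for {total} elements into {list(target)}")
        return [total // known if d == -1 else d for d in target]
    if known != total:
        raise ValueError(f"cannot reshape {total} elements into {list(target)}")
    return list(target)


print(resolve_reshape((32, 128), [32, 128]))    # [32, 128]
print(resolve_reshape((7, 10, 32), [-1, 320]))  # [7, 320]
```

By contrast, resolve_reshape((7, 128), [32, 128]) raises because 896 elements cannot fill a 32 × 128 tensor, which is the situation a fixed reshape target creates when the batch size varies.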

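A helper function can combine both: for each dimension, it returns the static value when it is known and falls back to the dynamic value (a scalar tensor) when the static value is None.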

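def get_shape(tensor):
  """Return Python ints where the static shape is known, scalar int32 tensors where it is None."""
  static_shape = tensor.shape.as_list()
  # tf.unstack needs a known rank; it splits the 1-D shape tensor into scalar tensors
  dynamic_shape = tf.unstack(tf.shape(tensor))
  dims = [d if s is None else s
          for s, d in zip(static_shape, dynamic_shape)]
  return dims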

b = tf.placeholder(tf.float32, [None, 10, 32])
shape = get_shape(b)
print(shape)
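# Merge the last two dimensions: shape[0] is still a tensor, shape[1] * shape[2] is the Python int 320
b = tf.reshape(b, [shape[0], shape[1] * shape[2]])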
shape = get_shape(b)
print(shape)

Output:

[<tf.Tensor 'unstack:0' shape=() dtype=int32>, 10, 32]
[<tf.Tensor 'unstack_1:0' shape=() dtype=int32>, 320]
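These outputs show the list returned by get_shape() mixing a scalar tensor (the batch size, known only at run time) with plain Python ints, which tf.reshape() accepts. The pure-Python sketch below applies the same per-dimension choice, with a concrete runtime shape standing in for the tensor from tf.shape(), to show what the list resolves to once a batch of 5 is fed. merge_static_dynamic is a hypothetical name and the batch size is an arbitrary assumption.

```python
from typing import List, Optional, Sequence


def merge_static_dynamic(static_shape: Sequence[Optional[int]],
                         runtime_shape: Sequence[int]) -> List[int]:
    """Same rule as get_shape(): prefer the static dim, fall back to the runtime dim."""
    return [d if s is None else s for s, d in zip(static_shape, runtime_shape)]


# Models b = tf.placeholder(tf.float32, [None, 10, 32]) fed with a batch of 5.
shape = merge_static_dynamic([None, 10, 32], (5, 10, 32))
print(shape)  # [5, 10, 32]
flat = [shape[0], shape[1] * shape[2]]
print(flat)   # [5, 320]
```

Preferring the static value keeps as much of the shape as possible known at graph-construction time, so only the genuinely unknown batch dimension depends on the data.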


Reposted from blog.csdn.net/heiheiya/article/details/81084179