tensorflow: 获取tensor维度

假设现在有一个tensor named tensor_a:
如果a是一个数组或其他类型,使用以下函数将a转换为tensor:
tensor_a = tf.convert_to_tensor(a)
tensor_a 的dim获取方法:
shape_a = tensor_a.get_shape()
dim_a = len(shape_a)
tensor_a各维度大小可由list列表获取:
tensor_a. get_shape ( ) . as_list ( )

假设tensor_a的dim为4,则其各维度大小分别为:
tensor_a.get_shape()[0].value
tensor_a.get_shape()[1].value
tensor_a.get_shape()[2].value

tensor_a.get_shape()[4].value

猜你喜欢

转载自blog.csdn.net/u010454261/article/details/80407111