Tensorflow - tf.expand_dims 学习

API: https://tensorflow.google.cn/api_docs/python/tf/expand_dims?hl=zh-cn 

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

在input的axis位置插入一维的张量

这个操作在input的维度中索引为axis的位置插入一维张量。维度索引axis从零开始; 如果指定负数,axis则从末尾向后计数

例子:

# 't' is a tensor of shape [2]
tf.shape(tf.expand_dims(t, 0))  # [1, 2]
tf.shape(tf.expand_dims(t, 1))  # [2, 1]
tf.shape(tf.expand_dims(t, -1))  # [2, 1]

# 't2' is a tensor of shape [2, 3, 5]
tf.shape(tf.expand_dims(t2, 0))  # [1, 2, 3, 5]
tf.shape(tf.expand_dims(t2, 2))  # [2, 3, 1, 5]
tf.shape(tf.expand_dims(t2, 3))  # [2, 3, 5, 1]

 这个操作在一个batch里面插入一个一维元素是很好用的

猜你喜欢

转载自blog.csdn.net/maka_uir/article/details/83858313