TensorFlow中张量转置操作tf.expand_dims用法详解

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

cudnn64_7.dll

Python 3.6.3

Windows 10

二、官方说明

给输入张量的形状增加1个维度

https://www.tensorflow.org/api_docs/python/tf/expand_dims

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

输入:

(1)input:输入张量

(2)axis:标量,指定在哪个维度上给输入张量增加一个维度,范围必须在 [-输入张量的秩,+输入张量的秩]

(3)name:输出结果张量的名称

(4)dim:标量,等同于axis,被弃用

返回结果:

(1)比输入张量多1维但是包含相同数据的张量

三、实例

(1)tf.expand_dims(input_tensor,0)

>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2])
>>> sess.run(tf.shape(tf.expand_dims(input,0)))
array([1, 2])
>>> sess.run(tf.expand_dims(input,0))
array([[1, 2]])
>>> sess.close()

(2)tf.expand_dims(input_tensor,1)

>>> import tensorflow as tf
input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([1])
>>> sess.run(tf.shape(tf.expand_dims(input,1)))
array([2, 1])
>>> sess.run(tf.expand_dims(input,1))
array([[1],
       [2]])

(3)tf.expand_dims(input_tensor,-1)

>>> import tensorflow as tf
>>> input = tf.constant([1,2], shape=[2])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([1])
>>> sess.run(tf.shape(tf.expand_dims(input,-1)))
array([2, 1])
>>> sess.run(tf.expand_dims(input,-1))
array([[1],
       [2]])

(4)多维拓展0

>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,0)))
array([1, 2, 3, 5])
>>> sess.run(tf.expand_dims(input,0))
array([[[[ 1,  2,  3,  4,  5],
         [ 6,  7,  8,  9, 10],
         [11, 12, 13, 14, 15]],

        [[16, 17, 18, 19, 20],
         [21, 22, 23, 24, 25],
         [26, 27, 28, 29, 30]]]])
>>> sess.close()

(5)多维拓展2

>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,2)))
array([2, 3, 1, 5])
>>> sess.run(tf.expand_dims(input,2))
array([[[[ 1,  2,  3,  4,  5]],

        [[ 6,  7,  8,  9, 10]],

        [[11, 12, 13, 14, 15]]],


       [[[16, 17, 18, 19, 20]],

        [[21, 22, 23, 24, 25]],

        [[26, 27, 28, 29, 30]]]])

(6)多维拓展3

>>> import tensorflow as tf
>>> input = tf.constant([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30], shape=[2,3,5])
>>> sess = tf.Session()
>>> sess.run(tf.shape(input))
array([2, 3, 5])
>>> sess.run(tf.shape(tf.expand_dims(input,3)))
array([2, 3, 5, 1])
>>> sess.run(tf.expand_dims(input,3))
array([[[[ 1],
         [ 2],
         [ 3],
         [ 4],
         [ 5]],

        [[ 6],
         [ 7],
         [ 8],
         [ 9],
         [10]],

        [[11],
         [12],
         [13],
         [14],
         [15]]],


       [[[16],
         [17],
         [18],
         [19],
         [20]],

        [[21],
         [22],
         [23],
         [24],
         [25]],

        [[26],
         [27],
         [28],
         [29],
         [30]]]])

猜你喜欢

转载自blog.csdn.net/sdnuwjw/article/details/85044994