tf.expand_dims()函数解析

tf.expand_dims(
    input,
    axis=None,
    name=None,
    dim=None
)

其中:
input是输入张量。

axis是指定扩大输入张量形状的维度索引值。

dim等同于轴,一般不推荐使用。

函数的功能是在给定一个input时,在axis轴处给input增加一个维度。

axis:

给定张量输入input,此操作为选择维度索引值,在输入形状的维度索引值的轴处插入1的维度。 维度索引值的轴从零开始; 如果您指定轴是负数,则从最后向后进行计数,也就是倒数。

import tensorflow as tf

# 't' is a tensor of shape [2]
t = tf.constant([1,2])
print(t.shape)
t1 = tf.expand_dims(t, 0)
print(t1.shape)
t2 = tf.expand_dims(t, 1)
print(t2.shape)
t3 = tf.expand_dims(t, 1)
print(t3.shape)

> (2,)
> (1, 2)
> (2, 1)
> (2, 1)
import tensorflow as tf
import numpy as np

# 't2' is a tensor of shape [2, 3, 5]
t2 = np.zeros((2,3,5))
print(t2.shape)
t3 = tf.expand_dims(t2, 0)
t4 = tf.expand_dims(t2, 2)
t5 = tf.expand_dims(t2, 3)
print(t3.shape)
print(t4.shape)
print(t5.shape)

> (2, 3, 5)
> (1, 2, 3, 5)
> (2, 3, 1, 5)
> (2, 3, 5, 1)

猜你喜欢

转载自blog.csdn.net/TeFuirnever/article/details/88797810
今日推荐