TensorFlow2.0中Tensor的合并与拆分

合并

将多个张量在某个维度上合并为一个张量,可以使用拼接和堆叠操作实现,拼接操作并不会产生新的维度,仅在现有的维度上合并,而堆叠会创建新维度。

拼接: tf.concat(tensors, axis)

tensors:保存了所有需要合并的张量 List
axis:指定需要合并的维度索引
合并时输入张量的维数必须匹配,并且除 axis 外的所有维数必须相等

举个例子:

>>> import tensorflow as tf
>>> a = tf.constant([[1, 2, 3], [4, 5, 6]])
>>> b = tf.constant([[7, 8, 9], [10, 11, 12]])
>>> c1 = tf.concat([a, b], 0)
>>> c1
<tf.Tensor: id=3, shape=(4, 3), dtype=int32, numpy=
array([[ 1,  2,  3],
       [ 4,  5,  6],
       [ 7,  8,  9],
       [10, 11, 12]])>
>>> c2 = tf.concat([a, b], 1)
>>> c2
<tf.Tensor: id=7, shape=(2, 6), dtype=int32, numpy=
array([[ 1,  2,  3,  7,  8,  9],
       [ 4,  5,  6, 10, 11, 12]])>

c1、c2 是 a、b 在不同维度下的拼接,未产生新维度

堆叠: tf.stack(tensors, axis)

tensors:保存了所有需要合并的张量 List
axis:指定新维度插入的位置
合并时需要所有待合并的张量 shape 完全一致,即维度和尺寸相同

直接用上面的 a、b 数据:

>>> z1 = tf.stack([a, b])	# axis默认为 0
>>> z1
<tf.Tensor: id=12, shape=(2, 2, 3), dtype=int32, numpy=
array([[[ 1,  2,  3],
        [ 4,  5,  6]],

       [[ 7,  8,  9],
        [10, 11, 12]]])>
>>> z2 = tf.stack([a, b],axis=-1)
>>> z2
<tf.Tensor: id=13, shape=(2, 3, 2), dtype=int32, numpy=
array([[[ 1,  7],
        [ 2,  8],
        [ 3,  9]],

       [[ 4, 10],
        [ 5, 11],
        [ 6, 12]]])>

分割: tf.split(tensor, num_or_size_splits, axis)

tensor:待分割张量
num_or_size_splits:切割方案。当 num_or_size_splits 为单个数值时,如 10,表
示等长切割为 10 份;当 num_or_size_splits 为 List 时,List 的每个元素表示每份的长
度,如[2,4,2,2]表示切割为 4 份,每份的长度依次是 2、4、2、2
axis:指定分割的维度索引号

举个例子:

>>> x = tf.random.normal([10,35,8])
>>> result1 = tf.split(x, num_or_size_splits=10, axis=0)
>>> len(result1)
10
>>> result1[0].shape
TensorShape([1, 35, 8])
>>> result2 = tf.split(x, num_or_size_splits=[4,2,2,2], axis=0)
>>> len(result2)
4
>>> result2[1].shape
TensorShape([2, 35, 8])

猜你喜欢

转载自blog.csdn.net/weixin_44613063/article/details/104073542