tf.stack
tf.stack(
values,
axis=0,
name='stack'
)
'''
Args:
values: A list of Tensor objects with the same shape and type.
axis: An int. The axis to stack along. Defaults to the first dimension. Negative values wrap around, so the valid range is [-(R+1), R+1).
name: A name for this operation (optional).
Returns:
output: A stacked Tensor with the same type as values.
'''
参数说明
用于tensor(矩阵)拼接的,输入是一个list(里面的每个元素都是tensor),输出是一个tensor
假设输入是由N个shape为(A,B,C)的tensor组成的一个list
如果沿着axis ==0 进行拼接,那么拼接后的输入的tensor的shape为(N,A,B,C)
如果沿着axis ==1 进行拼接,那么拼接后的输入的tensor的shape为(A,N,B,C)
如果沿着axis ==2 进行拼接,那么拼接后的输入的tensor的shape为(A,B,N,C)
如果沿着axis ==3 进行拼接,那么拼接后的输入的tensor的shape为(A,B,C,N)
请结合上面的例子1
import tensorflow as tf
a=tf.zeros([3,4])
b=tf.ones([3,4])
d_0=tf.stack([a,b])
d_1=tf.stack([a,b],axis=1)
d_2=tf.stack([a,b],axis=2)
with tf.Session() as sess:
print("A:\n",sess.run(a))
print("B:\n",sess.run(b))
print("D0:\n",sess.run(d_0))
print("D1:\n",sess.run(d_1))
print("D2:\n",sess.run(d_2))
输出
A:
[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
B:
[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]
D0:
[[[0. 0. 0. 0.]
[0. 0. 0. 0.]
[0. 0. 0. 0.]]
[[1. 1. 1. 1.]
[1. 1. 1. 1.]
[1. 1. 1. 1.]]]
D1:
[[[0. 0. 0. 0.]
[1. 1. 1. 1.]]
[[0. 0. 0. 0.]
[1. 1. 1. 1.]]
[[0. 0. 0. 0.]
[1. 1. 1. 1.]]]
D2:
[[[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]]
[[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]]
[[0. 1.]
[0. 1.]
[0. 1.]
[0. 1.]]]
请结合上面的例子2
x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
tf.stack([x, y, z]) # [[1, 4], [2, 5], [3, 6]] (Pack along first dim.)
tf.stack([x, y, z], axis=1) # [[1, 2, 3], [4, 5, 6]]