tf.stack讲解

tf.stack

tf.stack(
    values,
    axis=0,
    name='stack'
)
'''
Args:
	values: A list of Tensor objects with the same shape and type.
	axis: An int. The axis to stack along. Defaults to the first dimension. Negative values wrap around, so the valid range is [-(R+1), R+1).
	name: A name for this operation (optional).
Returns:
	output: A stacked Tensor with the same type as values.
'''

参数说明

用于tensor(矩阵)拼接的,输入是一个list(里面的每个元素都是tensor),输出是一个tensor
假设输入是由N个shape为(A,B,C)的tensor组成的一个list
如果沿着axis ==0 进行拼接,那么拼接后的输入的tensor的shape为(N,A,B,C)
如果沿着axis ==1 进行拼接,那么拼接后的输入的tensor的shape为(A,N,B,C)
如果沿着axis ==2 进行拼接,那么拼接后的输入的tensor的shape为(A,B,N,C)
如果沿着axis ==3 进行拼接,那么拼接后的输入的tensor的shape为(A,B,C,N)

请结合上面的例子1

import tensorflow as tf
a=tf.zeros([34])
b=tf.ones([34])
d_0=tf.stack([a,b])
d_1=tf.stack([a,b],axis=1)
d_2=tf.stack([a,b],axis=2)
with tf.Session() as sess:
    print("A:\n",sess.run(a))
    print("B:\n",sess.run(b))
    print("D0:\n",sess.run(d_0))
    print("D1:\n",sess.run(d_1))
    print("D2:\n",sess.run(d_2))

输出

A:
 [[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]]
B:
 [[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
D0:
 [[[0. 0. 0. 0.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]

 [[1. 1. 1. 1.]
  [1. 1. 1. 1.]
  [1. 1. 1. 1.]]]
D1:
 [[[0. 0. 0. 0.]
  [1. 1. 1. 1.]]

 [[0. 0. 0. 0.]
  [1. 1. 1. 1.]]

 [[0. 0. 0. 0.]
  [1. 1. 1. 1.]]]
D2:
 [[[0. 1.]
  [0. 1.]
  [0. 1.]
  [0. 1.]]

 [[0. 1.]
  [0. 1.]
  [0. 1.]
  [0. 1.]]

 [[0. 1.]
  [0. 1.]
  [0. 1.]
  [0. 1.]]]

请结合上面的例子2

x = tf.constant([1, 4])
y = tf.constant([2, 5])
z = tf.constant([3, 6])
tf.stack([x, y, z])  # [[1, 4], [2, 5], [3, 6]] (Pack along first dim.)
tf.stack([x, y, z], axis=1)  # [[1, 2, 3], [4, 5, 6]]

猜你喜欢

转载自blog.csdn.net/qq_32806793/article/details/85223426