一、环境
TensorFlow API r1.12
CUDA 9.2 V9.2.148
Python 3.6.3
二、官方说明
按指定轴(axis)进行张量连接操作(Concatenates Tensors)
tf.concat(
values,
axis,
name='concat'
)
输入:
(1)values:多个张量组成的列表或者单个张量
(2)axis:0维整形张量(整数),定义按照按个数据轴进行张量连接操作,其范围是[-输入张量的阶,+输入张量的阶]。[0,输入张量的阶]范围内的正数表示按照指定的axis轴进行连接操作,在[-输入张量的阶,0]之间的负数表示按照指定的(axis + 输入张量的阶)的轴进行连接操作
(3)name:可选参数,定义该张量连接操作的名称
输出:
输入张量按照指定轴连接后的一个结果张量
三、实例
(1)单个张量作为输入
>>> t1 = [[1,2,3],[4,5,6]]
>>> con1 = tf.concat(t1,0)
>>> shape1 = tf.shape(con1)
>>> with tf.Session() as sess:
... print(sess.run(con1))
... print(sess.run(shape1))
...
[1 2 3 4 5 6]
[6]
(2)多个张量组成的列表作为输入
按照0轴(行)进行连接:
>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con2 = tf.concat([t1,t2],0)
>>> shape2 = tf.shape(con2)
>>> with tf.Session() as sess:
... print(sess.run(con2))
... print(sess.run(shape2))
...
[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]
[4 3]
按照1轴(列)进行连接:
>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con3 = tf.concat([t1,t2],1)
>>> shape3 = tf.shape(con3)
>>> with tf.Session() as sess:
... print(sess.run(con3))
... print(sess.run(shape3))
...
[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]
[2 6]
>>>
按照-1轴(列)进行连接:
>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con4 = tf.concat([t1,t2],-1)
>>> shape4 = tf.shape(con4)
>>> with tf.Session() as sess:
... print(sess.run(con4))
... print(sess.run(shape4))
...
[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]
[2 6]
注意:如果想沿着一个新轴连接张量,则考虑使用stcak
不建议使用:tf.concat([tf.expand_dims(t, axis) for t in tensors],axis)
推荐使用:tf.stack(tensors, axis=axis)