TensorFlow中张量连接操作tf.concat用法详解

一、环境

TensorFlow API r1.12

CUDA 9.2 V9.2.148

Python 3.6.3

二、官方说明

按指定轴(axis)进行张量连接操作(Concatenates Tensors)

tf.concat(
    values,
    axis,
    name='concat'
)

输入:

(1)values:多个张量组成的列表或者单个张量

(2)axis:0维整形张量(整数),定义按照按个数据轴进行张量连接操作,其范围是[-输入张量的阶,+输入张量的阶]。[0,输入张量的阶]范围内的正数表示按照指定的axis轴进行连接操作,在[-输入张量的阶,0]之间的负数表示按照指定的(axis + 输入张量的阶)的轴进行连接操作

(3)name:可选参数,定义该张量连接操作的名称

输出:

输入张量按照指定轴连接后的一个结果张量

三、实例

(1)单个张量作为输入

>>> t1 = [[1,2,3],[4,5,6]]
>>> con1 = tf.concat(t1,0)
>>> shape1 = tf.shape(con1)
>>> with tf.Session() as sess:
...     print(sess.run(con1))
...     print(sess.run(shape1))
... 
[1 2 3 4 5 6]
[6]

(2)多个张量组成的列表作为输入

按照0轴(行)进行连接:

>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con2 = tf.concat([t1,t2],0)
>>> shape2 = tf.shape(con2)
>>> with tf.Session() as sess:
...     print(sess.run(con2))
...     print(sess.run(shape2))
... 
[[ 1  2  3]
 [ 4  5  6]
 [ 7  8  9]
 [10 11 12]]
[4 3]

按照1轴(列)进行连接:

>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con3 = tf.concat([t1,t2],1)
>>> shape3 = tf.shape(con3)
>>> with tf.Session() as sess:
...     print(sess.run(con3))
...     print(sess.run(shape3))
... 
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]
[2 6]
>>> 

按照-1轴(列)进行连接:

>>> t1 = [[1,2,3],[4,5,6]]
>>> t2 = [[7,8,9],[10,11,12]]
>>> con4 = tf.concat([t1,t2],-1)
>>> shape4 = tf.shape(con4)
>>> with tf.Session() as sess:
...     print(sess.run(con4))
...     print(sess.run(shape4))
... 
[[ 1  2  3  7  8  9]
 [ 4  5  6 10 11 12]]
[2 6]

注意:如果想沿着一个新轴连接张量,则考虑使用stcak

不建议使用:tf.concat([tf.expand_dims(t, axis) for t in tensors],axis)

推荐使用:tf.stack(tensors, axis=axis)

猜你喜欢

转载自blog.csdn.net/sdnuwjw/article/details/84960377