TensorFlow的tf.concat实例详细介绍

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_29957455/article/details/86100641

tf.concat函数:函数功能比较简单,主要用于连接两个数组
参数

  • values:需要连接的数组
  • axis:从哪个维度来连接数组

例子

  1. 一维数组
import tensorflow as tf

if __name__ == "__main__":
    a = [1,2,3]
    b = [4,5,6]
    c = tf.concat([a,b],0)
    sess = tf.InteractiveSession()
    print(sess.run(c)) #[1 2 3 4 5 6]

注意:axis参数不能超过数组的维度。如果超过数组的维度,如下:

    c = tf.concat([a,b],1)

则会报,ValueError: Shape must be at least rank 2 but is rank 1 for 'concat',意思是数组至少是二维,axis才能为1。

  1. 二维数组
    a = [[1,1],[2,2],[3,3]]
    b = [[4,4],[5,5],[6,6]]
    c = tf.concat([a,b],0)
    print(sess.run(c))
    """
    [[1 1]
     [2 2]
     [3 3]
     [4 4]
     [5 5]
     [6 6]]
    """
    c = tf.concat([a,b],1) #等价于tf.concat([a,b],-1)
    print(sess.run(c))
    """
    [[1 1 4 4]
     [2 2 5 5]
     [3 3 6 6]]
    """
  1. 三维数组
    a = [[[1,1],[2,2]],[[3,3],[4,4]]]
    b = [[[5,5]],[[6,6]]]

    c = tf.concat([a,b],1)
    print(sess.run(c))
    """
    [[[1 1]
      [2 2]
      [5 5]]

     [[3 3]
      [4 4]
      [6 6]]]
    """

注意:在使用tf.concat函数连接两个数组的时候,数组该维度必须是一致的,否则会报错,如下:

    c = tf.concat([a,b],0)

错误提示ValueError: Dimension 0 in both shapes must be equal, but are 2 and 1,意思是a在第1个维度上shape是2,而b在第一个维度上shape是1。
总结:如何来判断数组是否在该个维度上的shape是相同的呢?其实很简单,我们根据tf.concat的axis参数来去数组的[],0表示去掉最外面的一层,1去掉两层,以此类推,下面举例说明一下。
如:最后一个例子中的c = tf.concat([a,b],1),我们先将a去掉最外面两层[],变成了[1,1],[2,2]和[3,3],[4,4]],然后再将b去掉最外面两层[],变成了[5,5]和[6,6],此时再进行concat,可以发现此时的shape是相等的。

猜你喜欢

转载自blog.csdn.net/sinat_29957455/article/details/86100641
今日推荐