Tensorflow记录:tf.concat

定义

tf.concat(values, concat_dim)
作用是将待操作的张量进行连接/合并
函数中第一个参数是待操作的张量,用[]包起来;第二个参数指定某个维度,

本文的concat函数例子是因为在tf的工程中,静态图的tensor具体维度常常未知,所以需要用tf.palceholder()表示,非一般的已定维度的常量concat。
那种其实很好理解,无非叠加在一起罢了。

例子1

import tensorflow as tf
import numpy as np
i1 = tf.placeholder(tf.float32,[2,None,2])  
i2 = tf.placeholder(tf.float32,[2,None,None])

con_output = tf.concat([i1,i2],axis=0)

num = np.ones([2,2,2])  # 作为真正的输入
# num = tf.random_normal([2,2,2])  # 会报类型错误

with tf.Session() as sess:
    print(sess.run(con_output, feed_dict={i1:num,i2:num}))

在这里插入图片描述
通过shape可以看到最终的shape是(4,?,2),即将i1和i2的第一个维度合并/连接,同时由于第二个维度未定,所以为?;第三个维度因为i1的缘故,为2。

例子2

这里仅修改输入和concat的中的维度

i1 = tf.placeholder(tf.float32,[None,None,2])  
i2 = tf.placeholder(tf.float32,[None,None,None])
num = np.ones([2,2,2])  # 作为真正的输入
con_output = tf.concat([i1,i2],axis=0)

得到的shape是:(?, ?, 2),并不是因为num的第一个维度为2,shape就为4,这里只要i1,i2中的某个维度一同为None,最终的shape也为None

待操作的张量不能是tf.Tensor类型,会报错:
TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.

猜你喜欢

转载自blog.csdn.net/qq_38372240/article/details/104603132
今日推荐