关于TensorFlow中tf.reshape()及shape的问题

我看别人都有转载请声明出处,我也写上,:)

转载请申明出处,https://blog.csdn.net/sinat_28704977/article/details/80626689

更重要的,如有错误,请批评指正,不胜感谢。


先上代码:

import numpy as np
import tensorflow as tf
a = tf.constant([
    [ [ 1.0, 2.0, 3.0, 4.0 ],
      [ 5.0, 6.0, 7.0, 8.0 ],
      [ 8.0, 7.0, 6.0, 5.0 ],
      [ 4.0, 3.0, 2.0, 1.0 ] ],
    [ [ 4.0, 3.0, 2.0, 1.0 ],
      [ 8.0, 7.0, 6.0, 5.0 ],
      [ 1.0, 2.0, 3.0, 4.0 ],
      [ 5.0, 6.0, 7.0, 8.0 ] ]
])
c=a
image_shape = c.get_shape()
b=tf.reshape(a,[4,4,2])
a = tf.reshape(a, [ 1, 4, 4, 2 ])
#这里reshape是强制转换,是直接从上面矩阵挨个取值并转换为目标shape,转换后的图片不是如上图一样的两个4*4数组即(2,4,4),而是(1,4,4,2)。

with tf.Session() as sess:
    g,h=(c,image_shape[-1].value)
    d = image_shape[ 1: ].as_list()#[4,4]一维列表
    dim=1
    for test in d:#image_shape为(2,4,4)tensor_shape类型,维数为3
        print(test)
        dim*=int(test)
    e=tf.reshape(a,[-1,dim])
    f=tf.reshape(e,(32,1))
    print(e,'\n',e.get_shape())
    image = sess.run(e)
    image2 = sess.run(f)
    print(image,image2)

问题1

代码中,a的shape为(2,4,4),并不是我们直观认为的图像格式4*4*2。这是需要注意的一点。

问题2

首先,tensor有get_shape()方法获得tensor的shape,类型为tensor_shape,然后,shape[1:]的意思是取从第二个维度开始的shape,例如返回的tensor_shape为(1,4,4,2),加上as_list()就成为[4,4,2]一个一维列表.

问题3

e=tf.reshape(a,[-1,dim])

这里不是转为列向量,因为如果是一个列向量,那么shape应该是(1,32,1),如果是列向量,则是(32,1)。这里转为:最后一个维度是32的一个向量,也就是(1,32)。要区分清楚。
结果图:
这里写图片描述

猜你喜欢

转载自blog.csdn.net/sinat_28704977/article/details/80626689