Comparing None in keras.layers.Input and tf.placeholder; comparing keras.layers.Reshape and tf.reshape

Consider the following program:

import tensorflow as tf
import numpy as np
import keras

x = np.ones([2, 2, 36])
inputs = keras.layers.Input(shape=(None, None, 36)) # shape excludes the batch axis, so this tensor is rank 4: (?, ?, ?, 36)
outputs = keras.layers.Reshape((-1, 4))(inputs)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    y = sess.run(outputs, feed_dict={inputs: x})

    print('y: ',np.shape(y))

Result:

Traceback (most recent call last):
  File "keras.layers.Concatenate机制.py", line 16, in <module>
    y = sess.run(outputs, feed_dict={inputs: x})
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 877, in run
    run_metadata_ptr)
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1076, in _run
    str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (2, 2, 36) for Tensor 'input_1:0', which has shape '(?, ?, ?, 36)'

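At first glance the two Nones seem to stand for three dimensions. What actually happens is that the shape argument of keras.layers.Input does not include the batch dimension: Keras prepends a batch axis of unknown size (None), so shape=(None, None, 36) produces a rank-4 tensor of shape (?, ?, ?, 36), and a rank-3 array of shape (2, 2, 36) cannot be fed into it.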
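To make the shape bookkeeping concrete, here is a minimal pure-Python sketch that models shapes only (no TensorFlow involved). The helper names keras_input_full_shape and can_feed are hypothetical: the first mimics Keras prepending a batch axis to the shape given to keras.layers.Input, the second is a rough model of the rank/dimension check that sess.run performs before accepting a feed_dict value.

```python
def keras_input_full_shape(shape):
    """Hypothetical helper: shape of the tensor made by keras.layers.Input(shape=shape).

    Keras adds a batch axis of unknown size (None) in front of `shape`.
    """
    return (None,) + tuple(shape)


def can_feed(value_shape, placeholder_shape):
    """Rough model of the feed check: ranks must match, and every known
    (non-None) dimension must equal the corresponding value dimension."""
    if len(value_shape) != len(placeholder_shape):
        return False
    return all(p is None or p == v for v, p in zip(value_shape, placeholder_shape))


full = keras_input_full_shape((None, None, 36))
print(full)                           # (None, None, None, 36)
print(can_feed((2, 2, 36), full))     # False: rank 3 value vs rank 4 tensor
print(can_feed((2, 2, 2, 36), full))  # True
```

Feeding an array of shape (2, 2, 2, 36) instead of (2, 2, 36) would therefore be expected to pass the check in the program above.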

Now look at this one:

import tensorflow as tf
import numpy as np
import keras

x = np.ones([2, 36])
inputs = keras.layers.Input(shape=(None, 36))
outputs = keras.layers.Reshape((-1, 4))(inputs)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    y = sess.run(outputs, feed_dict={inputs: x})

    print('y: ',np.shape(y))

The result is:

Traceback (most recent call last):
  File "keras.layers.Concatenate机制_1.py", line 12, in <module>
    y = sess.run(outputs, feed_dict={inputs: x})
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 877, in run
    run_metadata_ptr)
  File "C:\ProgramData\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1076, in _run
    str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (2, 36) for Tensor 'input_1:0', which has shape '(?, ?, 36)'

Here a single None appears to cover two dimensions; again, the extra leading ? is the batch axis that Keras adds.

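In general, the tensor created by keras.layers.Input always has exactly one more dimension than the shape tuple you pass, because of that implicit batch axis; this holds whether the listed dimensions are None or fixed sizes. When the tuple is all None except the last entry, this looks like "one more ? than there are Nones": for example, keras.layers.Input(shape=(None, None, None, 36)) produces an input of shape '(?, ?, ?, ?, 36)'.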
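The rule can be stated without counting Nones at all: the Keras input's rank is len(shape) + 1, while a tf.placeholder's rank is exactly len(shape). A shape-only sketch with hypothetical helper names, assuming Keras simply prepends a None batch axis:

```python
def keras_input_shape(shape):
    # keras.layers.Input: a batch axis (None) is added in front of `shape`
    return (None,) + tuple(shape)


def placeholder_shape(shape):
    # tf.placeholder: the shape is used exactly as written, nothing is added
    return tuple(shape)


for s in [(None, 36), (None, None, None, 36), (2, None, None, 36)]:
    k = keras_input_shape(s)
    p = placeholder_shape(s)
    print(s, "-> Input:", k, "rank", len(k), "| placeholder:", p, "rank", len(p))
```

Note that shape=(2, None, None, 36) also yields a rank-5 tensor, (?, 2, ?, ?, 36), even though it contains only two Nones.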

This is exactly where tf.placeholder differs from keras.layers.Input.

Look at this code:

import tensorflow as tf
import numpy as np
import keras

x = np.ones([2, 36])
inputs = tf.placeholder(tf.int32, [None, 36])
outputs = keras.layers.Reshape((-1, 4))(inputs)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    y = sess.run(outputs, feed_dict={inputs: x})

    print('y: ',np.shape(y))

The result is:

y:  (2, 9, 4)

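This shows that each None in a tf.placeholder shape stands for exactly one dimension and nothing is added implicitly, so [None, 36] describes a 2-D input.
It also shows that keras.layers.Reshape always keeps axis 0 (the batch axis) unchanged and reshapes only the remaining axes: like the shape of Input, its target_shape excludes the batch dimension, so each of the 2 samples of 36 elements becomes (9, 4).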
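A minimal shape-only sketch of that behaviour, assuming target_shape contains exactly one -1 (as in all the examples here). The names resolve_minus_one and keras_reshape_output are hypothetical; they model the output shape of keras.layers.Reshape, not its implementation.

```python
from math import prod


def resolve_minus_one(n_elements, target_shape):
    """Replace the single -1 in target_shape so the sizes multiply to n_elements.

    Assumes target_shape contains exactly one -1.
    """
    known = prod(d for d in target_shape if d != -1)
    if n_elements % known != 0:
        raise ValueError("total size of new array must be unchanged")
    return tuple(n_elements // known if d == -1 else d for d in target_shape)


def keras_reshape_output(input_shape, target_shape):
    # Hypothetical model of keras.layers.Reshape: axis 0 (batch) is kept,
    # and target_shape is applied to the elements of one sample only.
    batch, per_sample = input_shape[0], prod(input_shape[1:])
    return (batch,) + resolve_minus_one(per_sample, target_shape)


print(keras_reshape_output((2, 36), (-1, 4)))  # (2, 9, 4)
```

The same function reproduces the higher-rank results that follow, e.g. (2, 2, 2, 2, 36) with target (-1, 4) gives (2, 72, 4).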
Another example:

import tensorflow as tf
import numpy as np
import keras

x = np.ones([2, 2, 2, 2, 36])
inputs = keras.layers.Input(shape=(None, None, None, 36))
outputs = keras.layers.Reshape((-1, 4))(inputs)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    y = sess.run(outputs, feed_dict={inputs: x})
    print('y: ',np.shape(y))

'''
Result:
y:  (2, 72, 4)
'''

Another example, this time with the first listed dimension fixed to 2:

import tensorflow as tf
import numpy as np
import keras

x = np.ones([2, 2, 2, 2, 36])
inputs = keras.layers.Input(shape=(2, None, None, 36))
outputs = keras.layers.Reshape((-1, 4))(inputs)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    y = sess.run(outputs, feed_dict={inputs: x})
    print('y: ',np.shape(y))

'''
Result:
y:  (2, 72, 4)
'''

And another, with a different target shape:

import tensorflow as tf
import numpy as np
import keras

x = np.ones([2, 2, 2, 2, 36])
inputs = keras.layers.Input(shape=(None, None, None, 36))
outputs = keras.layers.Reshape((4, -1, 4))(inputs)

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    y = sess.run(outputs, feed_dict={inputs: x})
    print('y: ',np.shape(y))

'''
Result:
y:  (2, 4, 18, 4)
'''
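The numbers in the last three results follow from the element count of a single sample; the batch axis of size 2 is carried through untouched, and -1 absorbs whatever is left after the fixed target sizes:

```latex
\begin{aligned}
\text{elements per sample} &= 2 \cdot 2 \cdot 2 \cdot 36 = 288,\\
\texttt{Reshape((-1, 4))}: &\quad -1 = 288 / 4 = 72 \;\Rightarrow\; (2, 72, 4),\\
\texttt{Reshape((4, -1, 4))}: &\quad -1 = 288 / (4 \cdot 4) = 18 \;\Rightarrow\; (2, 4, 18, 4).
\end{aligned}
```

Whether the second listed dimension is None or 2 makes no difference here, since the fed array fixes it to 2 either way.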

When tf.reshape is used instead of keras.layers.Reshape:

import tensorflow as tf
import numpy as np
import keras

x = np.ones([2, 36])
inputs = tf.placeholder(tf.int32, [None, 36])
outputs = tf.reshape(inputs, (-1, 4))

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    y = sess.run(outputs, feed_dict={inputs: x})

    print('y: ',np.shape(y))

'''
Result:
y:  (18, 4)
'''

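So tf.reshape behaves differently from keras.layers.Reshape: it reshapes across all dimensions, including the batch axis, so the 2 × 36 = 72 elements become (18, 4) instead of (2, 9, 4).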
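The contrast can be summarized with another shape-only sketch (hypothetical helper names, assuming exactly one -1 in the target shape): tf.reshape resolves -1 over the total element count, while keras.layers.Reshape resolves it over the elements of one sample.

```python
from math import prod


def resolve_minus_one(n_elements, target_shape):
    # Assumes exactly one -1 in target_shape.
    known = prod(d for d in target_shape if d != -1)
    if n_elements % known != 0:
        raise ValueError("total size of new array must be unchanged")
    return tuple(n_elements // known if d == -1 else d for d in target_shape)


def tf_reshape_output(input_shape, target_shape):
    # tf.reshape: all axes, batch included, are reshaped together
    return resolve_minus_one(prod(input_shape), target_shape)


def keras_reshape_output(input_shape, target_shape):
    # keras.layers.Reshape: axis 0 is kept, only the rest is reshaped
    return (input_shape[0],) + resolve_minus_one(prod(input_shape[1:]), target_shape)


print(tf_reshape_output((2, 36), (-1, 4)))     # (18, 4)
print(keras_reshape_output((2, 36), (-1, 4)))  # (2, 9, 4)
```

If you want the Keras-style result from tf.reshape, you have to keep the batch axis yourself, for example tf.reshape(inputs, [tf.shape(inputs)[0], -1, 4]), which should give (2, 9, 4) for the (2, 36) input above.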


Reposted from blog.csdn.net/weixin_43331915/article/details/83305263