It seems best to convert the data to float32 first (except for complex-typed data); otherwise you may get errors.
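Here is a minimal sketch of that conversion step. It assumes the data is a NumPy array, as in the session below. The helper name to_model_dtype is hypothetical. It casts real-valued arrays to float32, which is the default Keras float type (tf.keras.backend.floatx() returns 'float32' unless you change it). Complex arrays are returned unchanged, because casting them to a real dtype would discard the imaginary part.

```python
import numpy as np

def to_model_dtype(x):
    """Hypothetical helper: cast real data to float32, leave complex data as is."""
    x = np.asarray(x)
    if np.iscomplexobj(x):
        # Casting complex -> float32 would silently drop the imaginary part.
        return x
    return x.astype(np.float32)

xx = np.random.rand(8, 10, 10)        # float64 by default
print(to_model_dtype(xx).dtype)       # float32
print(to_model_dtype(xx + 1j).dtype)  # complex128 (left unchanged)
```

Because the check happens before the array is wrapped in a tensor, the tf.constant(..., tf.float32) call in the session below receives data that already has the intended dtype.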
>>> xx.shape
(8, 10, 10)
>>> xx2=tf.constant(xx,tf.float32)
>>> inputs=keras.Input(shape=xx.shape[1:],tensor=xx2)
>>> with tf.Session() as sess:
...     print(sess.run(inputs))
The above just inspects the basic inputs tensor. However, when I try to print the output of the BN layer directly, I get an error. What's up?
>>> xx3=keras.layers.BatchNormalization(input_shape=xx.shape[1:])(inputs)
>>> with tf.Session() as sess:
...     print(sess.run(xx3))
Traceback (most recent call last):
  File "D:\python\lib\site-packages\tensorflow_core\python\client\session.py", line 1365, in _do_call
    return fn(*args)
  File "D:\python\lib\site-packages\tensorflow_core\python\client\session.py", line 1350, in _run_fn
    target_list, run_metadata)
  File "D:\python\lib\site-packages\tensorflow_core\python\client\session.py", line 1443, in _call_tf_sessionrun
    run_metadata)
tensorflo