tf.squeeze()用于压缩张量中为1的轴

tf.squeeze()用于压缩张量中为1的轴

squeeze(
    input,
    axis=None,
    name=None,
    squeeze_dims=None
)

该函数会除去张量中形状为1的轴。

例子:

import tensorflow as tf

raw = tf.Variable(tf.random_normal(shape=(1, 3, 2)))
squeezed = tf.squeeze(raw)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(raw.shape)
    print('-----------------------------')
    print(sess.run(squeezed).shape)

输出如:

(1, 3, 2)
-----------------------------
(3, 2)

猜你喜欢

转载自blog.csdn.net/loseinvain/article/details/78994695
今日推荐