tensorflow.squeeze() 函数

tensorflow.squeeze(input, squeeze_dims=None, name=None)

参数: input  -->  输入的tensor

             squeeze_dims = None  -->默认None是删除input中所有大小是1的维度,若指定位置则删除所指定位置大小是1的维度

             name -->名称(可选)

原始数据

y = tf.expand_dims(y,axis=-1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    value = sess.run(y)
    print (value)
    print (y.shape)
[[[[[[1]
     [2]
     [3]]]


   [[[4]
     [5]
     [6]]]]]]
(1, 1, 2, 1, 3, 1)

删除所有大小是1的维度:

z = tf.squeeze(y)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    value = sess.run(z)
    print (value)
    print (z.shape)
    print ("z[0][1]: ",value[0][1])
[[1 2 3]
 [4 5 6]]
(2, 3)
z[0][1]:  2

删除位置是3,5的大小是1的维度(从0起)

z1 = tf.squeeze(y, [3, 5])
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    value = sess.run(z1)
    print (value)
    print (z1.shape)
[[[[1 2 3]
   [4 5 6]]]]
(1, 1, 2, 3)

猜你喜欢

转载自blog.csdn.net/Muzi_Water/article/details/81389248