numpy.squeeze()的用法

函数原型:numpy.squeeze(a, axis=None)
函数功能:把数组中shape中为1的维度去掉。默认删除a数组中所有shape中为1的维度,axis指定要删除的维度,axis=0表示第0维,若是该维度的shape不为1,则会报错。

例如:

a = [ [[1,2]], [[3,4]] ]    # shape为2*1*2
# 删除中间为1的维度后
a = [ [1,2], [3,4] ]  # 看起来就像是将“穿”的夹层多余的衣服(括号)脱掉一层

实例:

import numpy as np
x = np.array([[[0], [1], [2]]])
print(x)
"""x=
[[[0]
  [1]
  [2]]]
"""
print(x.shape)  # (1, 3, 1)  两个维度为1的维
# 删除数组中第1个维度为1的维
x1 = np.squeeze(x, axis=0)   # axis=0表示第1维
# 删除数组中第2个维度为1的维
x2 = np.squeeze(x, axis=2)   # axis=2表示第3维
# 默认删除所有维度为1的维
x3 = np.squeeze(x)  # 从数组的形状中删除单维条目,即把shape中为1的维度去掉
# 删除第2维
x4 = np.squeeze(x, axis=1)  # 会报ValueError错,因为数组第2维不为1

print(x1) 
''' [[0]
 [1]
 [2]]'''
print(x1.shape)    # (3, 1)
print(x2)    # [[0 1 2]]
print(x2.shape)    #(1, 3)
print(x3)    # [0 1 2]
print(x3.shape)    # (3,)

参考:numpy.squeeze()的用法
numpy官方文档

猜你喜欢

转载自blog.csdn.net/u011208984/article/details/109309695