numpy数组中reshape和squeeze函数的使用

参考了:http://blog.csdn.net/zenghaitao0128/article/details/78512715,作了一些自己的补充。

numpy中的reshape函数和squeeze函数是深度学习代码编写中经常使用的函数,需要深入的理解。

其中,reshape函数用于调整数组的轴和维度,而squeeze函数的用法如下,

语法:numpy.squeeze(a,axis = None)

 1)a表示输入的数组;
 2)axis用于指定需要删除的维度,但是指定的维度必须为单维度,否则将会报错;
 3)axis的取值可为None 或 int 或 tuple of ints, 可选。若axis为空,则删除所有单维度的条目;
 4)返回值:数组
 5) 不会修改原数组;
作用:从数组的形状中删除单维度条目,即把shape中为1的维度去掉
 

举例:

numpy的reshape和squeeze函数:

import numpy as np
e = np.arange(10)
print(e)
一维数组:[0 1 2 3 4 5 6 7 8 9]
f = e.reshape(1,1,10)
print(f)

三维数组:(第三个方括号里有十个元素)

[[[0 1 2 3 4 5 6 7 8 9]]],前两维的秩为1

g = f.reshape(1,10,1)
print(g)

三维数组:(第二个方括号里有十个元素)

[[[0]
  [1]
  [2]
  [3]
  [4]
  [5]
  [6]
  [7]
  [8]
  [9]]]
h = e.reshape(10,1,1)
print(h)
三维数组:(第一个方括号里有10个元素)
[[[0]]

 [[1]]

 [[2]]

 [[3]]

 [[4]]

 [[5]]

 [[6]]

 [[7]]

 [[8]]

 [[9]]]

利用squeeze可以把数组中的1维度去掉(从0开始指定轴),以下为不加参数axis,去掉所有1维的轴:

m = np.squeeze(h)
print(m)

以下指定去掉第几轴

n = np.squeeze(h,2)
print(n)
去掉第三轴,变成二维数组,维度为(10,1):
[[0]
 [1]
 [2]
 [3]
 [4]
 [5]
 [6]
 [7]
 [8]
 [9]]

再举一个例子:

p = np.squeeze(g,2)
print(p)

去掉第2轴,得到二维数组,维度为(1,10):

[[0 1 2 3 4 5 6 7 8 9]]
p = np.squeeze(g,0)
print(p)

去掉第0轴,得到二维数组,维度为(10,1):

[[0]
 [1]
 [2]
 [3]
 [4]
 [5]
 [6]
 [7]
 [8]
 [9]]

在matplotlib画图中,非单维的数组在plot时候会出现问题,(1,nx)不行,但是(nx, )可以,(nx,1)也可以。

如下:

import matplotlib.pyplot as plt
squares =np.array([[1,4,9,16,25]]) 
print(squares.shape)    

square的维度为(1,5),无法画图:

做如下修改:

plt.plot(np.squeeze(squares))    
plt.show()

square的维度为(5,),可以画图:

或者做如下修改

squares1 = squares.reshape(5,1)
plt.plot(squares1)  
plt.show()

square的维度为(5,1),可以画图:

猜你喜欢

转载自blog.csdn.net/u011816283/article/details/83445781