深度解密numpy中的shape

楔子

不管数据分析,数据挖掘,机器学习等,都会经常遇到数组的 shape 操作。那么你有没有想过,为什么一个一维数组,能 shape 成任意维度的高维数组呢。

NumPy 也提供高效的数据 shape 操作,那么它是如何实现这种操作,用到什么数据结构,原理又是什么?

弄明白这个问题后,再去使用 NumPy、TensorFlow,就会瞬间清晰很多。

解密shape

一个一维数组,长度为 12,为什么能变化为二维 (12,1) 或 (2,6)、三维 (12,1,1) 或 (2,3,2)、四维 (12,1,1,1) 或 (2,3,1,2) 呢?总之,为什么能变化为任意多维度呢

我们下面就来看看,先导入numpy

import numpy as np

# 创建一个数组a,从0开始,间隔为2,包含12个元素
a = np.arange(0, 24, 2)

# 打印一下
print(a)  # [ 0  2  4  6  8 10 12 14 16 18 20 22]

如上数组 a,NumPy 会将其解读成两个结构,一个 buffer,还有一个 view。

buffer 的示意图如下所示:

view 是解释 buffer 的一个结构,比如数据类型,flags 信息等:

print(a.dtype)  # int32
print(a.flags)
"""
  C_CONTIGUOUS : True
  F_CONTIGUOUS : True
  OWNDATA : True
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False
"""

使用 a[6] 访问数组 a 中 index 为 6 的元素。从背后实现看,NumPy 会辅助一个轴,轴的取值为 0 到 11。示意图如下所示:

所以,借助这个轴 i,a[6] 就会被索引到元素 12,如下所示:

至此,大家要建立一个轴的概念。接下来,做一次 reshape 变化,变化数组 a 的 shape 为 (2,6):

b = a.reshape((2, 6))
print(b)
"""
[[ 0  2  4  6  8 10]
 [12 14 16 18 20 22]]
"""

此时,NumPy 会建立两个轴,假设为 i、j,i 的取值为 0 到 1,j 的取值为 0 到 5,示意图如下:

使用 b[1][2]或者b[1, 2] 获取元素到 16:

print(b[1, 2])  # 16

两个轴的取值分为 1、2,如下图所示,定位到元素 16:

平时,可能会有人混淆两个 shape,(12,) 和 (12,1),其实前者一个轴,后者两个轴,示意图分别如下。

前者是一个轴,取值从 0 到 11;后者是两个轴,i 轴取值从 0 到 11,j 轴取值从 0 到 0。

至此,大家要建立两个轴的概念。并且,通过上面几幅图看到,无论 shape 如何变化,变化的是视图,底下的 buffer 始终未变。

接下来,上升到三个轴,变化数组 a 的 shape 为 (2,3,2):

c = a.reshape((2, 3, 2))
print(c) 
"""
[[[ 0  2]
  [ 4  6]
  [ 8 10]]

 [[12 14]
  [16 18]
  [20 22]]]
"""

数组 c 有三个轴,取值分别为 0 到 1,0 到 2,0 到 1,示意图如下所示:

注意体会,i、j、k 三个轴,其值的分布规律。如果去掉 i 轴取值为 1 的单元格后:

实际就对应到数组 c 的前半部分元素:

c = a.reshape((2, 3, 2))
print(c[0: 1])
"""
[[[ 0  2]
  [ 4  6]
  [ 8 10]]]
"""

至此,三个轴的 reshape 已经讲完,再说一个有意思的问题。

还记得,原始的一维数组 a 吗?它一共有 12 个元素,后来,我们变化它为数组 c,shape 为 (2,3,2),那么如何升级为 4 维或任意维呢?

4 维可以为:(1,2,3,2),示意图如下:

看到,轴 i 索引取值只有 0,它被称为自由维度,可以任意插入到原数组的任意轴间。比如,5 维可以为:(1,2,1,3,2):

至此,你应该完全理解 reshape 操作后的魔法:

  • buffer 是个一维数组,永远不变;
  • 变化的 shape 通过 view 传达;
  • 取值仅有 0 的轴为自由轴,它能变化出任意维度。

关于 reshape 操作,最后再说一点,reshape 后的数组,仅仅是原来数组的视图 view,并没有发生复制元素的行为,这样才能保证 reshape 操作更为高效。

import numpy as np

a = np.array([1, 2, 3, 4])
b = a.reshape((2, 2))

b[0, 0] = 100
print(a)  # [100   2   3   4]

b是a的一个view,b改变了会影响到a。感觉有点类似golang的切片啊。

在了解完 reshape 操作的奥秘后,相信大家都建立轴和多轴的概念,这对灵活使用高维数组很有帮助。

猜你喜欢

转载自www.cnblogs.com/traditional/p/12510615.html
今日推荐