numpy的reshape和transpose机制解释

reshape和transpose都是非常高效的算子,究其原因,是因为二者均没有在内存中重新排列数据,只是对数据的shape或strides等信息进行了改变。下面分别简介。

ndarray的base和strides属性

为了更好地理解reshape和transpose算子,需要对ndarray的shape, base, strides三个属性有所了解,其中shape很容易理解,就不多说了,下面简单介绍一下base和strides。

base

base参考:https://numpy.org/doc/stable/reference/generated/numpy.ndarray.base.html

如果一个ndarray是通过其他ndarray经过某种操作创建出来的,那么其base就会指向最初的源头。
比如下面例子中,ab, c的源头,所以b.base 和 c.base 都等于a,而a本身没有base,所以是None。

import numpy as np

a = np.array([0, 1, 2, 3, 4, 5])
print(a)  # ==> [0 1 2 3 4 5]
print(a.base)  # ==> None

b = a.reshape([2, 3])
print(b)  # ==>
# [[0 1 2]
#  [3 4 5]]
print(b.base)  # ==> [0 1 2 3 4 5]

c = b.transpose([1, 0])
print(c)  # ==>
# [[0 3]
#  [1 4]
#  [2 5]]
print(c.base)  # ==> [0 1 2 3 4 5]

strides

strides参考:https://numpy.org/doc/stable/reference/generated/numpy.ndarray.strides.html

ndarray的每一个维度(axis)都有一个strides,表示从数组在某个维度进行遍历的内存偏移量

比如在下面的例子中,数组a三个维度的strides分别是(48, 16, 4),意思是:

  • a[0, 0, 0]a[0, 0, 1] = 1的内存偏移量是4字节,1个int型数字是4字节
  • a[0, 0, 0]a[0, 1, 0] = 4的内存偏移量是16字节,因为需要偏移4个int型数字
  • a[0, 0, 0]a[1, 0, 0] = 12的内存偏移量是48字节,因为需要偏移12个int型数字
import numpy as np

a = np.arange(24).reshape([2, 3, 4])

print(a.strides)  # ==> (48, 16, 4)
print(a)  # ==>
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
#
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

reshape

reshape仅仅只是改变了数组的shape属性,比如把shape从 ( 4 , ) (4,) (4,)改成 ( 2 , 2 ) (2,2) (2,2)。通过下面的测试代码,可以明白reshape的下列性质:

  • 如果我们从最后一个维度开始,依次向前循环打印数组的话,会发现无论怎么样reshape,数组打印的顺序不会发生任何变化。也就是说无论reshape多少次,数组打印顺序不变。
  • 类似于python的浅拷贝,reshape之后,尽管变量发生了变化,但是变量内的数据体却未被碰过。下面列子中,改变reshape后的b的第一个值,发现所有相关的变量的第一个值都发生了变化,所以就可以知道,经reshape后,变量用于保存数据的那块内存没有被碰过。
import numpy as np

a = np.arange(4)  # a = torch.arange(4)
print(a)  # ==> [0 1 2 3]
print(a.shape)  # ==> (4,)

b = a.reshape([2, 2])  # b = a.reshape([2, 2])
print(b)  # ==> [[0 1], [2 3]]
print(b.shape)  # ==> (2, 2)

c = b.reshape([-1])  # c = torch.reshape(b, [-1])
print(c)  # ==> [0 1 2 3]

b[0, 0] = 100
print(a)  # ==> [100   1   2   3]
print(b)  # ==> [[100 1], [2 3]]
print(c)  # ==> [100   1   2   3]

transpose

transpose改变了数组的维度(axis)排列顺序。比如对于二维数组,如果我们把两个维度的顺序互换,那就是我们很熟悉的矩阵转置。而transpose可以在更多维度的情况下生效。transpose的入参是输出数组的维度排列顺序,序号从0开始计数。

下面例子中我们改变了transpose后的b的第一个元素的值,发现a也随之改变,说明transpose也没有去碰数组的内存。那么问题来了,既然数组没有在内存中重新排列,那么打印顺序是受什么影响而发生了改变呢?是strides。

import numpy as np

a = np.arange(24).reshape([2, 3, 4])
print(a.base)  # ==>
# [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]

print(a.shape)  # ==> (2, 3, 4)
print(a.strides)  # ==> (48, 16, 4)
print(a)  # ==>
# [[[ 0  1  2  3]
#   [ 4  5  6  7]
#   [ 8  9 10 11]]
#
#  [[12 13 14 15]
#   [16 17 18 19]
#   [20 21 22 23]]]

b = a.transpose([1, 2, 0])
print(b.shape)  # ==>(3, 4, 2)
print(b.strides)  # ==> (16, 4, 48)
print(b)  # ==>
# [[[ 0 12]
#   [ 1 13]
#   [ 2 14]
#   [ 3 15]]
#
#  [[ 4 16]
#   [ 5 17]
#   [ 6 18]
#   [ 7 19]]
#
#  [[ 8 20]
#   [ 9 21]
#   [10 22]
#   [11 23]]]


b[0, 0, 0] = 100
print(a)  # ==>
# [[[100   1   2   3]
#   [  4   5   6   7]
#   ...]]]

print(b)  # ==>
# [[[100  12]
#   [  1  13]
#   ...]]]

下面图示一下strides的含义。

在这里插入图片描述

首先明确一个很重要的概念,strides都是相对于base数组而言进行遍历的,所以无论是a还是b,遍历时需要参考的源头都是a.base / b.base,也就是最上面的一维数组。

数组a的strides情况我们前面已经讲过了,接下来主要看看b

  • 由于b.strides最后一个维度的值是48,所以b[0, 0, 1]b[0, 0, 0]b.base中偏移48字节后的数字,也就是12
  • b.strides中间维度的值是4,所以b[0, 1, 0]b[0, 0, 0]b.base中偏移4字节后的数字,也就是1
  • b.strides第一个维度的值是16,所以b[1, 0, 0]b[0, 0, 0]b.base中偏移16字节后的数字,也就是4

所以ranspose操作只是改变了strides的顺序,没有重新排列内存中的数据。

总结

前面我们在解释reshape和transpose的机制时,分别从ndarray的shape和strides属性进行了侧重解释。实际上reshape既改变shape也改变strides,而transpose也可能会改变shape。

但这两个算子均不会在内存中重新排列数据。

猜你喜欢

转载自blog.csdn.net/bby1987/article/details/113729250
今日推荐