np.tensordot 的理解和使用

Numpy是使用最广的科学计算库,对于多维数组的操作更是在实践中用的最多,而且也是比较困惑的地方,但是用好了事半功倍,今天讲一下numpy的 tensordot 的使用,这个函数在卷积神经网络的卷积中用到。

数组的基本属性

数组基本属性:维度、形状、strides(跨越数组各个维度所需要经过的字节数)、数组元素个数、元素占用字节数、数组占用空间,用以下例子说明:

>>> X = np.random.randint(0,9,(3,4,5))
>>> X
array([[[5, 1, 3, 6, 5],
        [5, 1, 8, 0, 5],
        [8, 5, 7, 8, 5],
        [8, 1, 5, 1, 4]],

       [[7, 7, 7, 7, 6],
        [0, 3, 4, 4, 6],
        [8, 4, 2, 1, 1],
        [6, 3, 4, 5, 4]],

       [[0, 2, 8, 0, 7],
        [6, 5, 8, 2, 2],
        [0, 1, 2, 3, 5],
        [7, 8, 7, 7, 6]]])
>>> X.ndim
3
>>> X.shape
(3, 4, 5)
>>> X.strides
(160, 40, 8)
>>> X.size
60
>>> X.itemsize
8
>>> X.nbytes
480

多维数组轴向取值

数组的取值看似简单但是在高纬度下,还是需要注意一下取法.
最原始取法,如取第一个元素

>>> X[0][0][0]
5

按轴取值则不同,取出来的值可能是数组,仍以上述为例,X.shape为(3,4,5),说明是3维数组,或者说有三个轴0,1,2. 第0轴上3个元素,第1轴上4个元素,第2轴上5个元素,如果要取轴上元素如何写?看以下例子。以下取第0轴第一个元素。

>>> X[0]
array([[5, 1, 3, 6, 5],
       [5, 1, 8, 0, 5],
       [8, 5, 7, 8, 5],
       [8, 1, 5, 1, 4]])
>>> X[1]
array([[7, 7, 7, 7, 6],
       [0, 3, 4, 4, 6],
       [8, 4, 2, 1, 1],
       [6, 3, 4, 5, 4]])
>>> X[2]
array([[0, 2, 8, 0, 7],
       [6, 5, 8, 2, 2],
       [0, 1, 2, 3, 5],
       [7, 8, 7, 7, 6]])
>>> X[4]
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: index 4 is out of bounds for axis 0 with size 3

取1轴上的元素

>>> X[:,0,:]
array([[5, 1, 3, 6, 5],
      [7, 7, 7, 7, 6],
      [0, 2, 8, 0, 7]])
>>> X[:,5,:]
Traceback (most recent call last):
 File "<stdin>", line 1, in <module>
IndexError: index 5 is out of bounds for axis 1 with size 4

可以看到,按轴取出的元素实际上是一个子数组!

Tensordot的使用

进入正题, 运行如下代码:

>>> np.random.seed(10)
>>> A = np.random.randint(0,9,(3,4,5))
>>> B = np.random.randint(0,9,(4,5,2))
>>> np.tensordot(A, B, [(1,2), (0,1)])
array([[233,  89],
       [250, 234],
       [199, 244]])

解释:
(1,2) 是对A而言,不是取第1,2轴,而是除去1,2 轴,所以要取的是第0轴
(0,1) 是对B而言,不是取第0,1轴,而是除去0,1 轴,所以要取的是第2轴

以上两句是精华

A的形状是(3,4,5),第0轴上有3个元素,取法上面讲了;B的形状(4,5,2),第2轴上有2个元素,所以结果形状是(3,2)

Tensordot 的作用就是把取出的子数组做点乘操作,即是 np.sum(a*b) 操作。
我们来验证一下,上述的说法看结果形状(3,2)的第一个元素:A第0轴上第一个元素与B第2轴上的第一个元素点乘。

>>> A[0]
array([[4, 0, 1, 0, 1],
       [8, 0, 8, 6, 4],
       [3, 0, 4, 6, 8],
       [1, 8, 4, 1, 3]])
>>> B[:,:,0]
array([[8, 2, 5, 2, 3],
       [4, 0, 3, 2, 0],
       [0, 0, 1, 0, 5],
       [4, 6, 2, 3, 6]])
>>> np.sum(A[0]*B[:,:,0])
233

结果完全正确!就是这么简单,多说都是废话!

发布了36 篇原创文章 · 获赞 42 · 访问量 2万+

猜你喜欢

转载自blog.csdn.net/weixin_28710515/article/details/90230842