详解Python-Numpy库函数take_along_axis()【由索引生成新数组的一系列函数中的其中一个函数】

函数take_along_axis()用于由索引数组生成新的矩阵。

要想完全理解它是比较难的,昊虹君从昨天下午想到现在,也感觉只理解了其一半,不过有一半理解总比没有好吧,下面就把昊虹君的理解给大家分享分享。

Begin…
提问:由已有数组的索引生成新的数组为什么要用函数take_along_axis(),我用Numpy库ndarray对象的切片操作不行么?

答案是:Numpy库ndarray对象的切片操作不是万能的,比如下面的两种情况它就不能解决,而下面两种情况可以用函数take_along_axis()解决。

情况一:我由argsort()函数得到了矩阵元素按从小到大排序的索引,接下来我想由个这个排序索引得到一个新的矩阵,这个新矩阵元素就是按从小到大排序的索引。这种情况下光靠切片操作就很难实现这个功能了。不信的话诸君可以试一试,反正昊虹君是试了的,很麻烦。但是此时用函数take_along_axis()就很方便,示例如下:

import numpy as np

A = np.array([[10, 30, 20], [60, 40, 50]])
B = np.sort(A, axis=1)
index1 = np.argsort(A, axis=1)
C = np.take_along_axis(A, index1, axis=1)

运行结果如下:
在这里插入图片描述
从这个示例可以看出,函数take_along_axis()很方便的帮我们由索引值数组index1按顺序取出了A中的元素形成了数组C。

情况二:
现有三维矩阵A如下:

A = np.arange(2*3*4).reshape([2, 3, 4])

在这里插入图片描述
在这里插入图片描述现在要实现下面这个目标:
选取A[0]的第0行和A[1]的第1行构成一个新的三维矩阵B,B矩阵的形状为(2, 1, 4)。
这个目标用切片操作是无法实现的,昊虹君也尝试过直接用切片实现这个目标,但无奈没有成功。
但是这个目标用函数take_along_axis()就很容易实现了,示例代码如下:

import numpy as np
A = np.arange(2*3*4).reshape([2, 3, 4])

index1 = np.zeros([2, 1, 1]).astype('int')

index1[0, 0, :] = 0
index1[1, 0, :] = 1

B = np.take_along_axis(A, index1, axis=1)

运行结果如下:
在这里插入图片描述

在这里插入图片描述
这几句代码虽然短,但是很不好理解,大家可参考博文 https://blog.csdn.net/baidu_37157624/article/details/123124561
并结合我下面的叙述来理解。
①上面的索引数组index1中每个元素具体的值为原数组中的索引值,其本身的索引值为生成的新矩阵的元素的索引值。

②结合index1的形状(2,1,1)和具体的值,可知index1的作用为:
选取A[0]的第0行和A[1]的第1行构成一个新的三维矩阵B,B矩阵的形状为(2, 1, 4)。

③axis =1 ,代表第1个轴对齐(第1个轴即行维度),第1个轴对齐(行维度对齐)的意思是A的第0个维度和B的第0个维度是一致的,即B的列维度和A的列维度是一致的,即都是四列。所以这里index1的形状的第三个值为1而不是4,因为行维度对齐,所以取一行相当于就取了四列。
当axis=-1或axis=0时,昊虹君对运行的结果还不太能理解,但很多情况下我们都是让axis =1,这符合我们的思维习惯,所以就暂时先理解到axis =1吧。

④index1的维度数应该和A的相同。

⑤二维以下时实现上面的功能是完全可以用切片或take()函数实现的,take()函数的说明见博文https://blog.csdn.net/wenhao_ir/article/details/125714322【搜索关键词“take()”】,但是当维度大于等于三时此时切片或take()函数就难以实现上述目标了。

猜你喜欢

转载自blog.csdn.net/wenhao_ir/article/details/125819211
今日推荐