目录
经常使用NumPy的小伙伴会遇到axis与keepdims这两个参数,今天笔者来给大家解剖一下。
以NumPy中sum()函数为例,【】里为笔者翻译与理解。
numy.sum()函数定义与说明
函数定义:sum(a, axis=None, dtype=None, out=None, keepdims=np._NoValue, initial=np._NoValue)
作用:Sum of array elements over a given axis. 【指定轴方向上矩阵元素求和。】
Parameters 参数
----------
axis : None or int or tuple of ints, optional 【可选参数,可以为None, 整数或者整数元组】
Axis or axes along which a sum is performed. 【某个轴或多个轴方向上进行求和运算。】
The default, axis=None, will sum all of the elements of the input array.
【默认情况下,axis=None,会计算输入矩阵 中所有元素的和。】
If axis is negative it counts from the last to the first axis.
【如果axis为负数,则从最后一维开始往第一维计算。后续有说明。】
.. versionadded:: 1.7.0【1.7.0版本新增:】
If axis is a tuple of ints, a sum is performed on all of the axes
specified in the tuple instead of a single axis or all the axes as
before.【如果axis为整数元组,会对该元组指定的所有轴方向上元素进行求和。】
keepdims : bool, optional【布尔类型,可选参数。keepdims是keep dimensions的简写】
If this is set to True, the axes which are reduced are left
in the result as dimensions with size one.
【如果这个参数为True,被删去的维度在结果矩阵中就被设置为一。
举例:如果一个2*3*4的三维矩阵,axis=0,keepdims默认为False,则结果矩阵被降维至3*4(二维矩阵);
如果keepdims=True, 则矩阵维度保持不变,还是三维,只是第零个维度由2变为1,即1*3*4的三维矩阵】
With this option, the result will broadcast correctly against the input array.
【有了这个选项,结果矩阵就可以与原 始输入矩阵进行正确的广播运算。】
If the default value is passed, then `keepdims` will not be
passed through to the `sum` method of sub-classes of
`ndarray`, however any non-default value will be.
【如果设定为默认值,那么keepdims不会传递给ndarray子类的sum方法;但是任何非默认值都会传递。】
If the sub-class' method does not implement `keepdims` any
exceptions will be raised.【如果子类方法未实现keepdims,则将引发异常。】
Returns 返回值
-------
sum_along_axis : ndarray【指定轴方向上的和:ndarray类型】
An array with the same shape as `a`, with the specified axis removed.
【与输入参数'a'一样的shape,只是指定轴被移除了。】
If `a` is a 0-d array, or if `axis` is None, a scalar is returned.
【如果'A`是0维矩阵,或者如果`axis`是None,则返回一个标量】
If an output array is specified, a reference to `out` is returned.
【如果输出矩阵有指定,则返回'out'的引用。】
代码示例、说明及输出
#! /usr/bin/env python
# -*- coding: utf-8 -*-
# import the necessary packages
import numpy as np
a1 = np.arange(24).reshape(2, 3, 4) # 创建一个2*3*6矩阵
print("a1:\n", a1)
以2*3*4的三维矩阵为例,它的shape为(2, 3, 4),axis的值是从零开始的,只能是{0, 1, 2}范围内。
0 | 1 | 2 | 3 |
4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 |
下面是笔者的理解方式
axis = i,则矩阵运算沿着第i个下标变化的方向进行操作
axis(正) | 0 | 1 | 2 | 从最外层往里剥开计算 |
shape | 2 | 3 | 4 | |
axis(负) | -3 | -2 | -1 | 从最里层往外剥开计算 |
axis = 0 或axis = -3
print("np.sum(a1, axis=0):\n", np.sum(a1, axis=0))
print("np.sum(a1, axis=-3):\n", np.sum(a1, axis=-3))
当axis=0时,第零轴被移除(removed),原三维矩阵2*3*4会被降维至3*4二维矩阵。
axis(正) | 0 | 1 | 2 |
shape | 3 | 4 | |
axis(负) | -3 | -2 | -1 |
axis = 0:第零轴维数为2,下标范围为{0,1},这个轴上会有两个数同时相加,具体参考如下表格。
注意红色数字下标。
a[0][0][0] + a[1][0][0] |
a[0][0][1] + a[1][0][1] |
a[0][0][2] + a[1][0][2] |
a[0][0][3]+ a[1][0][3] |
a[0][1][0] + a[1][1][0] |
a[0][1][1] + a[1][1][1] |
a[0][1][2] + a[1][1][2] |
a[0][1][3] + a[1][1][3] |
a[0][2][0] + a[1][2][0] |
a[0][2][1] + a[1][2][1] |
a[0][2][2] + a[1][2][2] |
a[0][2][3] + a[1][2][3] |
图片上同样背景色的数字相加求和。
当axis=-3时,求和运算是在倒数第三个轴方向上进行,在此矩阵示例中即为第零轴,二者的计算结果一样。
代码输出结果
np.sum(a1, axis=0):
[[12 14 16 18]
[20 22 24 26]
[28 30 32 34]]
axis = 1 或 axis = -2
print("np.sum(a1, axis=1):\n", np.sum(a1, axis=1))
print("np.sum(a1, axis=-2):\n", np.sum(a1, axis=-2))
当axis=1时,第一轴被移除(removed),原三维矩阵2*3*4会被降维至2*4二维矩阵。
axis(正) | 0 | 1 | 2 |
shape | 2 | 4 | |
axis(负) | -3 | -2 | -1 |
axis = 1:第一轴维数为3,下标范围为{0,1,2},这个轴上会有三个数同时相加,具体参考如下表格。
注意红色数字下标。
a[0][0][0] + a[0][1][0] + a[0][2][0] |
a[0][0][1] + a[0][1][1] + a[0][2][1] |
a[0][0][2] + a[0][1][2] + a[0][2][2] |
a[0][0][3] + a[0][1][3] + a[0][2][3] |
a[1][0][0] + a[1][1][0] + a[1][2][0] |
a[1][0][1] + a[1][1][1] + a[1][2][1] |
a[1][0][2] + a[1][1][2] + a[1][2][2] |
a[1][0][3] + a[1][1][3] + a[1][2][3] |
图片上同样背景色的数字相加求和。
当axis=-2时,求和运算是在倒数第二个轴方向上进行,在此矩阵示例中即为第一轴,二者的计算结果一样。
代码输出结果
np.sum(a1, axis=1):
[[12 15 18 21]
[48 51 54 57]]
axis = 2或axis = -1
print("np.sum(a1, axis=2):\n", np.sum(a1, axis=2))
print("np.sum(a1, axis=-1):\n", np.sum(a1, axis=-1))
当axis=2时,第零轴被移除(removed),原三维矩阵2*3*4会被降维至2*3二维矩阵。
axis(正) | 0 | 1 | 2 |
shape | 2 | 3 | |
axis(负) | -3 | -2 | -1 |
axis = 2:第二轴维数为4,下标范围为{0,1,2, 3},这个轴上会有四个数同时相加,具体参考如下表格。
注意红色数字下标。
a[0][0][0] + a[0][0][1] + a[0][0][2] + a[0][0][3] |
a[0][1][0] + a[0][1][1] + a[0][1][2] + a[0][1][3] |
a[0][2][0] + a[0][2][1] + a[0][2][2] + a[0][2][3] |
a[1][0][0] + a[1][0][1] + a[1][0][2] + a[1][0][3] |
a[1][1][0] + a[1][1][1] + a[1][1][2] + a[1][1][3] |
a[1][2][0] + a[1][2][1] + a[1][2][2] + a[1][2][3] |
图片上同样背景色的数字相加求和。
当axis=-1时,求和运算是在倒数第一个轴方向上进行,在此矩阵示例中即为第二轴,二者的计算结果一样。
代码输出结果
np.sum(a1, axis=2):
[[ 6 22 38]
[54 70 86]]
axis = (0, 2)
print("np.sum(a1, axis=(0,2)):\n", np.sum(a1, axis=(0,2)))
print("np.sum(a1, axis=(-3,-1)):\n", np.sum(a1, axis=(-3,-1))) # axis=(-1,-3) 这样写也可以
当axis=(0, 2)时,第零轴和第二轴都被移除(removed),原矩阵2*3*4会被降维成只有3个元素的一维矩阵,shape为(3,)。
axis(正) | 0 | 1 | 2 |
shape | 3 | ||
axis(负) | -3 | -2 | -1 |
axis=(0, 2):第零轴维数为2,下标范围为{0,1};第二轴维数为4,下标范围为{0,1,2,3}。会有2*4=8个数同时相加。
图片上同样背景色的数字相加求和。
代码输出结果
指定的第零轴和第二轴被移除,输出结果矩阵被降维,变成一维矩阵,里面有三个元素。
np.sum(a1, axis=(0,2)):
[ 60 92 124]
axis = (0, 1)与axis = (1, 2)的情况大家自行实验。
以上所有代码都是在keepdims=False的情况下。下面看一下keepdims=True的例子。
axis = (0, 2),keepdims=True
print("np.sum(a1, axis=(0,2),keepdims=True):\n", np.sum(a1, axis=(0,2),keepdims=True))
【如果keepdims这个参数为True,被删去的维度在结果矩阵中就被设置为一。
举例:如果一个2*3*4的三维矩阵,axis=0,keepdims默认为False,则结果矩阵被降维至3*4(二维矩阵);
如果keepdims=True, 则矩阵维度保持不变,还是三维,只是第零个维度由2变为1,即1*3*4的三维矩阵】
With this option, the result will broadcast correctly against the input array.
【有了这个选项,结果矩阵就可以与原始输入矩阵进行正确的广播运算。】
axis(正) | 0 | 1 | 2 |
shape | 3 | ||
axis(负) | -3 | -2 | -1 |
原理参考上面的笔者翻译,输出结果还是三维矩阵,只是指定的第零轴和第二轴全部变为一维了,即shape为1*3*1。
这样做的目的是方便后续与原始矩阵的广播运算(加减乘除等)。
代码输出结果
np.sum(a1, axis=(0,2),keepdims=True):
[[[ 60]
[ 92]
[124]]]
如果觉得这种理解困难,可以参考其他博客上介绍的方法:
《Python之NumPy(axis=0/1/2...)的透彻理解》
总结与扩展
NumPy中的mean(), average(), all(), any(), concat()等函数中的axis参数与sum()函数原理一样。
带有axis参数的常用函数请参考下面表格。
TensorFlow concat()的axis参数 与 PyTorch中cat()的dim参数也与sum()函数的axis参数一样的原理。
关于numpy.all()与any()可以参考另外一篇博客《Python NumPy.all()与any()函数理解》。
函数名 | 作用 |
mean | 求平均值 |
average | 求加权平均值 |
std | 求标准差 |
var | 求方差 |
min, max | 求最小值,求最大值 |
prod | 计算所有元素的乘积 |
all | 判断所有元素是否都为True |
any | 判断是否有元素为True |
concat | 矩阵拼接 |
sort | 给矩阵排序 |
argsort | 排序后的矩阵中各元素在原始输入矩阵中的索引位置 |
argmax | 求最大值在原始输入矩阵中的索引位置 |