https://numpy.org/doc/stable/reference/generated/numpy.where.html
用法一
import numpy as np
np.where(condition, x, y)
condition中需要包含一个待处理的ndarray(这里记为A),那么对于A中的每个元素,如果满足条件,则将这个元素替换为x,否则,将这个元素替换为y。一个例子如下:
import numpy as np
A = np.array([1, 2, 3, 4, 5])
res = np.where(A > 3, 1, 0)
print(res)
输出:
[0 0 0 1 1]
即,对于A中的每个元素,逐个对比其是否大于3,是则将对应位置"替换"为1,不是则为0。
用法二
import numpy as np
np.where(condition)
类似的,对于A中的每个元素,检查其是否满足condition,如果是则返回其坐标。一个例子如下:
import numpy as np
A = np.array([1, 2, 3, 4, 5])
res = np.where(A > 3)
print(res)
输出:
(array([3, 4], dtype=int64),)
即下标3和下标4处的值满足条件。可以发现,这里返回的其实是一个tuple,这个tuple的维度与数组本身维度一致,用于返回多维的坐标。此外,通过简单套娃也可以将满足条件的值给取出来:
A = np.array([1, 2, 3, 4, 5])
res = np.where(A > 3)
print(A[res])
输出:
[4 5]
此时效果等同于直接使用A[A>3]。