2.7 花哨的索引

2.7 花哨的索引

import numpy as np
rand = np.random.RandomState(42)
x = rand.randint(100, size=10)
print(x)
[51 92 14 71 60 20 82 86 74 74]
[x[3], x[7], x[2]]
[71, 86, 14]
ind = [3, 7, 2]
x[ind]
array([71, 86, 14])

利用花哨的索引,结果的形状与索引数组一致,而不是与被索引数组的形状一致。

ind = np.array([[3, 7], [4, 5]])
x[ind]
array([[71, 86],
       [60, 20]])
X = np.arange(12).reshape((3, 4))
X
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])

二维索引,对应的是行和列的索引,如果索引的维度不同,会广播后再索引。

row = np.array([0, 1, 2])
col = np.array([2, 1, 3])
X[row, col]
array([ 2,  5, 11])
X[row[:, np.newaxis], col]  # 索引是3X1和1x3,先广播再索引
array([[ 2,  1,  3],
       [ 6,  5,  7],
       [10,  9, 11]])
row[:, np.newaxis], col
(array([[0],
        [1],
        [2]]), array([2, 1, 3]))
X[2, [2, 0, 1]]  # 组合使用,与简单索引
array([10,  8,  9])
X[1:, [2, 0, 1]]  # 组合使用,与切片
array([[ 6,  4,  5],
       [10,  8,  9]])
mask = np.array([1, 0, 1, 0], dtype=bool)
X[row[:, np.newaxis], mask]  # 组合使用,与掩码
array([[ 0,  2],
       [ 4,  6],
       [ 8, 10]])

示例:选择随机点

花哨的索引的常见用途是从一个矩阵中选择行的子集,如有一个 N×D

的矩阵,表示在 D 个维度中的 N

个点。以下是一个二维正态分布的点组成的数组:

mean = [0, 0]
cov = [[1, 2], [2, 5]]
X = rand.multivariate_normal(mean, cov, 100)
X.shape
(100, 2)

该数组为100行2列的二维数组,画出散点:

%matplotlib inline
import matplotlib.pyplot as plt
import seaborn; seaborn.set()
plt.scatter(X[:, 0], X[:, 1]);

X  # 二维数组的内容
array([[-0.644508  , -0.46220608],
       [ 0.7376352 ,  1.21236921],
       [ 0.88151763,  1.12795177],
       [ 2.04998983,  5.97778598],
       [-0.1711348 , -2.06258746],
       [ 0.67956979,  0.83705124],
       [ 1.46860232,  1.22961093],
       [ 0.35282131,  1.49875397],
       [-2.51552505, -5.64629995],
       [ 0.0843329 , -0.3543059 ],
       [ 0.19199272,  1.48901291],
       [-0.02566217, -0.74987887],
       [ 1.00569227,  2.25287315],
       [ 0.49514263,  1.18939673],
       [ 0.0629872 ,  0.57349278],
       [ 0.75093031,  2.99487004],
       [-3.0236127 , -6.00766046],
       [-0.53943081, -0.3478899 ],
       [ 1.53817376,  1.99973464],
       [-0.50886808, -1.81099656],
       [ 1.58115602,  2.86410319],
       [ 0.99305043,  2.54294059],
       [-0.87753796, -1.15767204],
       [-1.11518048, -1.87508012],
       [ 0.4299908 ,  0.36324254],
       [ 0.97253528,  3.53815717],
       [ 0.32124996,  0.33137032],
       [-0.74618649, -2.77366681],
       [-0.88473953, -1.81495444],
       [ 0.98783862,  2.30280401],
       [-1.2033623 , -2.04402725],
       [-1.51101746, -3.2818741 ],
       [-2.76337717, -7.66760648],
       [ 0.39158553,  0.87949228],
       [ 0.91181024,  3.32968944],
       [-0.84202629, -2.01226547],
       [ 1.06586877,  0.95500019],
       [ 0.44457363,  1.87828298],
       [ 0.35936721,  0.40554974],
       [-0.90649669, -0.93486441],
       [-0.35790389, -0.52363012],
       [-1.33461668, -3.03203218],
       [ 0.02815138,  0.79654924],
       [ 0.37785618,  0.51409383],
       [-1.06505097, -2.88726779],
       [ 2.32083881,  5.97698647],
       [ 0.47605744,  0.83634485],
       [-0.35490984, -1.03657119],
       [ 0.57532883, -0.79997124],
       [ 0.33399913,  2.32597923],
       [ 0.6575612 , -0.22389518],
       [ 1.3707365 ,  2.2348831 ],
       [ 0.07099548, -0.29685467],
       [ 0.6074983 ,  1.47089233],
       [-0.34226126, -1.10666237],
       [ 0.69226246,  1.21504303],
       [-0.31112937, -0.75912097],
       [-0.26888327, -1.89366817],
       [ 0.42044896,  1.85189522],
       [ 0.21115245,  2.00781492],
       [-1.83106042, -2.91352836],
       [ 0.7841796 ,  1.97640753],
       [ 0.10259314,  1.24690575],
       [-1.91100558, -3.66800923],
       [ 0.13143756, -0.07833855],
       [-0.1317045 , -1.64159158],
       [-0.14547282, -1.34125678],
       [-0.51172373, -1.40960773],
       [ 0.69758045,  0.72563649],
       [ 0.11677083,  0.88385162],
       [-1.16586444, -2.24482237],
       [-2.23176235, -2.63958101],
       [ 0.37857234,  0.69112594],
       [ 0.87475323,  3.400675  ],
       [-0.86864365, -3.03568353],
       [-1.03637857, -1.18469125],
       [-0.53334959, -0.37039911],
       [ 0.30414557, -0.5828419 ],
       [-1.47656656, -2.13046298],
       [-0.31332021, -1.7895623 ],
       [ 1.12659538,  1.49627535],
       [-1.19675798, -1.51633442],
       [-0.75210154, -0.79770535],
       [ 0.74577693,  1.95834451],
       [ 1.56094354,  2.9330816 ],
       [-0.72009966, -1.99780959],
       [-1.32319163, -2.61218347],
       [-2.56215914, -6.08410838],
       [ 1.31256297,  3.13143269],
       [ 0.51575983,  2.30284639],
       [ 0.01374713, -0.11539344],
       [-0.16863279,  0.39422355],
       [ 0.12065651,  1.13236323],
       [-0.83504984, -2.38632016],
       [ 1.05185885,  1.98418223],
       [-0.69144553, -1.56919875],
       [-1.2567603 , -1.125898  ],
       [ 0.09619333, -0.64335574],
       [-0.99658689, -2.35038099],
       [-1.21405259, -1.77693724]])
X[0]  # 二维数组中第0个元素
array([-0.644508  , -0.46220608])
X[0, 0]  # 二维数组中第0个元素的横坐标
-0.6445079962363565
X[:, 0]  # 二维数组中元素的横坐标组成的数组
array([-0.644508  ,  0.7376352 ,  0.88151763,  2.04998983, -0.1711348 ,
        0.67956979,  1.46860232,  0.35282131, -2.51552505,  0.0843329 ,
        0.19199272, -0.02566217,  1.00569227,  0.49514263,  0.0629872 ,
        0.75093031, -3.0236127 , -0.53943081,  1.53817376, -0.50886808,
        1.58115602,  0.99305043, -0.87753796, -1.11518048,  0.4299908 ,
        0.97253528,  0.32124996, -0.74618649, -0.88473953,  0.98783862,
       -1.2033623 , -1.51101746, -2.76337717,  0.39158553,  0.91181024,
       -0.84202629,  1.06586877,  0.44457363,  0.35936721, -0.90649669,
       -0.35790389, -1.33461668,  0.02815138,  0.37785618, -1.06505097,
        2.32083881,  0.47605744, -0.35490984,  0.57532883,  0.33399913,
        0.6575612 ,  1.3707365 ,  0.07099548,  0.6074983 , -0.34226126,
        0.69226246, -0.31112937, -0.26888327,  0.42044896,  0.21115245,
       -1.83106042,  0.7841796 ,  0.10259314, -1.91100558,  0.13143756,
       -0.1317045 , -0.14547282, -0.51172373,  0.69758045,  0.11677083,
       -1.16586444, -2.23176235,  0.37857234,  0.87475323, -0.86864365,
       -1.03637857, -0.53334959,  0.30414557, -1.47656656, -0.31332021,
        1.12659538, -1.19675798, -0.75210154,  0.74577693,  1.56094354,
       -0.72009966, -1.32319163, -2.56215914,  1.31256297,  0.51575983,
        0.01374713, -0.16863279,  0.12065651, -0.83504984,  1.05185885,
       -0.69144553, -1.2567603 ,  0.09619333, -0.99658689, -1.21405259])

用花哨的索引选择随机而不重复的20个索引值,并用这些索引值选择原始数组对应的值:

indices = np.random.choice(X.shape[0], 20, replace=False)
indices
array([94, 76, 22,  0, 77, 36, 32, 58, 54, 70, 50, 92, 44, 38, 65, 46, 79,
       68, 67, 71])
selection = X[indices]  # 花哨的索引
selection.shape
(20, 2)
plt.scatter(X[:, 0], X[:, 1], alpha=0.3)
plt.scatter(selection[:, 0], selection[:, 1], facecolor='none', edgecolor='b', s=200);

用花哨的索引修改值

x = np.arange(10)
i = np.array([2, 1, 8, 4])
x[i] = 99
x
array([ 0, 99, 99,  3, 99,  5,  6,  7, 99,  9])
x[i] -= 10  # 赋值语句
x
array([ 0, 89, 89,  3, 89,  5,  6,  7, 89,  9])
x[[0, 0]]  # 索引是个数组,依次索引0和0,相当于索引第0个值两次
array([0, 0])
x[[0, 0]] = [4, 6]  # 重复索引,赋值的4会被6覆盖
x
array([ 6, 89, 89,  3, 89,  5,  6,  7, 89,  9])

猜你喜欢

转载自blog.csdn.net/ceerfuce/article/details/81226151
2.7
今日推荐