问题
假如数据集有3类,怎么把一个庞大的数组集3类,放在不同的数组里。
分析
首先庞大数据集分类,肯定不能一个一个遍历,而且强烈避免个人的操作,需要借助于numpy处理。
示例
数据集,可以看出数据集为3类,我们要x也分成3类
x = [[1,2],[2,9],[3,9],[4,4],[5,9],[6,6],[7,7],[8,8],[9,9]]
y = [0, 0, 0, 1, 1, 1, 2, 2, 2]
先转化为numpy
x = np.array([[1,2],[2,9],[3,9],[4,4],[5,9],[6,6],[7,7],[8,8],[9,9]])
y = np.array([0, 0, 0, 1, 1, 1, 2, 2, 2])
x
array([[1, 2],
[2, 9],
[3, 9],
[4, 4],
[5, 9],
[6, 6],
[7, 7],
[8, 8],
[9, 9]])
y
array([0, 0, 0, 1, 1, 1, 2, 2, 2])
得到每个类别在y中的标记,也可以说是在x中的标记,value是指类别的名称或者ID,3个类别得到3个标记数组
labels = [y == value for value in range(3)]
print(labels)
[array([ True, True, True, False, False, False, False, False, False]),
array([False, False, False, True, True, True, False, False, False]),
array([False, False, False, False, False, False, True, True, True])]
根据标记数组得到,对应x中的3组数据,记住这里应该是x的row=mask的column,假如x是一维的,x的column=mask的column也可以自适应。
t = [x[ci] for ci in labels]
print(t)
[array([[1, 2],
[2, 9],
[3, 9]]),
array([[4, 4],
[5, 9],
[6, 6]]),
array([[7, 7],
[8, 8],
[9, 9]])]
#%%
# 假如x是一维的,x的column=mask的column也可以自适应。
t1 = [y.T[ci] for ci in labels]
print(t1)
t2 = [y[ci] for ci in labels]
print(t2)
[array([0, 0, 0]), array([1, 1, 1]), array([2, 2, 2])]
[array([0, 0, 0]), array([1, 1, 1]), array([2, 2, 2])]