One-hot encoding with NumPy

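import numpy as np

def transform_one_hot(labels):
    # Number of classes; assumes labels are non-negative integers 0..K-1
    n_labels = np.max(labels) + 1
    # Row i of the identity matrix is the one-hot vector for class i,
    # so indexing the identity matrix with the labels picks one row per label
    one_hot = np.eye(n_labels)[labels]
    return one_hot

labels = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
one_hot = transform_one_hot(labels)
print(one_hot)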
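
The function above relies on the labels being non-negative integers starting at 0, because they are used directly as row indices into the identity matrix. When labels are arbitrary values (for example 3, 7, 10), one possible approach is to first map them to consecutive indices with np.unique. The following is only a minimal sketch under that assumption; the helper names one_hot_any_labels and decode_one_hot are hypothetical and do not come from the original post.

```python
import numpy as np

def one_hot_any_labels(labels):
    """Hypothetical helper: one-hot encode labels that need not be 0..K-1."""
    labels = np.asarray(labels).reshape(-1)
    # classes: sorted unique labels; indices: position of each label in classes
    classes, indices = np.unique(labels, return_inverse=True)
    indices = indices.reshape(-1)
    # Pick one identity-matrix row per label, as in transform_one_hot
    one_hot = np.eye(len(classes), dtype=int)[indices]
    return one_hot, classes

def decode_one_hot(one_hot, classes):
    """Hypothetical helper: recover the original labels from one-hot rows."""
    # The position of the 1 in each row is the index into classes
    return classes[np.argmax(one_hot, axis=1)]

labels = np.array([3, 7, 7, 10])
encoded, classes = one_hot_any_labels(labels)
print(encoded)
print(decode_one_hot(encoded, classes))
```

Returning classes alongside the encoded matrix keeps the mapping between columns and original labels, which is what makes decoding with argmax possible.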

https://blog.csdn.net/zhongranxu/article/details/79332154


Reposted from blog.csdn.net/qq_38826019/article/details/83184134