对任意shape的label用numpy进行one_hot编码

def get_one_hot(labels, nb_classes):
    res = np.eye(nb_classes)[np.array(labels).reshape(-1)]
    return res.reshape(list(labels.shape)+[nb_classes])

解释:

  1. np.array(labels).reshape(-1)是将labels展平, 比如将[[2,1],[3,2],[0,0]](shape为[3,2])展平为[2,1,3,2,0,0](shape为[6, ])
  2. np.eye(nb_classes)生成对角矩阵, 左上-右下对角线上的值为1, 其余为0
  3. res = np.eye(nb_classes)[np.array(labels).reshape(-1)]根据展平后的结果, 取np.eyes()中的对应行, 得到新的矩阵
  4. 最后把res按照labels的shape复原, 增加一个nb_class维度

猜你喜欢

转载自blog.csdn.net/weixin_42561002/article/details/87861640
今日推荐