Python手动实现One-hot编码

版权声明:转载注明出处 https://blog.csdn.net/york1996/article/details/82354898
zeros = torch.zeros(real_size,CLASSES_COUNT)
for k in range(real_size ):
    zeros[k, batch_label[k]] = 1

代码解读:

zeros是一个tensor,它的行数是real_size,列数是类别数 CLASSES_COUNT。每一行代表一个样本,每个样本在某一列的值是1,其他是0。所以第一行的代码就是为了全是0的一个tensor。

batch_label是一个列表,其中存放了每行中哪个列值是1,相当于存放了一系列的索引值。元素在batch_label中的索引代表在zeros中的行索引,元素值代表在zeros中列索引。

然后遍历样本数,k行的batch_label[k]置为1。

[0, 4, 2, 1, 1, 4, 1, 0, 2, 0] 这是batch_label
tensor([[1., 0., 0., 0., 0.],这是最终生成的zeros tensor
        [0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [1., 0., 0., 0., 0.]])

猜你喜欢

转载自blog.csdn.net/york1996/article/details/82354898