Convert a PyTorch vector to one-hot encoding

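# Convert a PyTorch vector to one-hot encoding
import torch

# Original vector: one class index per row (shape 4x1, dtype int64)
index = torch.tensor([[1], [2], [0], [3]])
# Zero matrix of 4 samples x 4 classes to hold the encoding
onehot = torch.zeros(4, 4)
# Along dim 1, set onehot[i][index[i][0]] = 1 for each row i
onehot.scatter_(1, index, 1)
print(onehot)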

# Result
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.]])
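
The scatter_ call writes the value 1 into each row at the column named by index along dim 1. The sketch below assumes the labels arrive as a 1-D integer tensor. It wraps the same call in a helper, spells out the equivalent plain loop, and checks both against torch.nn.functional.one_hot, PyTorch's built-in one-hot function. The names to_one_hot and to_one_hot_loop are hypothetical and used only for illustration.

```python
import torch
import torch.nn.functional as F


def to_one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """One-hot encode a 1-D tensor of integer labels with scatter_.

    Hypothetical helper name, for illustration only.
    """
    # scatter_ needs an int64 index with the same number of dims as the
    # target, so reshape the labels from shape (N,) to (N, 1).
    index = labels.long().view(-1, 1)
    onehot = torch.zeros(labels.size(0), num_classes)
    # Along dim 1: onehot[i][index[i][0]] = 1 for every row i.
    onehot.scatter_(1, index, 1)
    return onehot


def to_one_hot_loop(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Explicit-loop reference showing what scatter_(1, index, 1) does.

    Hypothetical helper name, for illustration only.
    """
    onehot = torch.zeros(labels.size(0), num_classes)
    for i, label in enumerate(labels.tolist()):
        onehot[i][label] = 1
    return onehot


labels = torch.tensor([1, 2, 0, 3])
a = to_one_hot(labels, 4)
b = to_one_hot_loop(labels, 4)
# Built-in alternative; it returns an int64 tensor, so cast to float to compare.
c = F.one_hot(labels, num_classes=4).float()

print(torch.equal(a, b), torch.equal(a, c))  # True True
```

When num_classes is omitted, F.one_hot infers it as the largest label plus one and returns an integer tensor. The scatter_ version instead lets you pick the size and dtype of the zeros buffer up front.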

Source: blog.csdn.net/ao1886/article/details/108981585