pytorch 向量转化为one-hot编码

#pytorch 向量转化为one-hot编码
import torch

#原始向量
index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

#结果
tensor([[0., 1., 0., 0.],
		[0., 0., 1., 0.],
		[1., 0., 0., 0.],
		[0., 0., 0., 1.]])

猜你喜欢

转载自blog.csdn.net/ao1886/article/details/108981585
今日推荐