サンプルコード
>>> import torch
>>> label = torch.arange(10).view(-1, 1)
>>> label
tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7],
[8],
[9]])
>>> label_onehot = torch.zeros(10, 10).scatter_(1, label, 1)
>>> label_onehot
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
チップ
scatter()とscatter_()の役割は同じですが、scatter()が元のTensorを直接変更しないのに対し、scatter_()は変更する点が異なります。
引用文献
https://pytorch.org/docs/master/tensors.html#torch.Tensor.scatter_