[Pytorch]ワンホットエンコーディングにscatter_()を使用する

サンプルコード

>>> import torch
>>> label = torch.arange(10).view(-1, 1)
>>> label
tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9]])
>>> label_onehot = torch.zeros(10, 10).scatter_(1, label, 1)
>>> label_onehot
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])

チップ

scatter()とscatter_()の役割は同じですが、scatter()が元のTensorを直接変更しないのに対し、scatter_()は変更する点が異なります。

引用文献

https://pytorch.org/docs/master/tensors.html#torch.Tensor.scatter_

おすすめ

転載: blog.csdn.net/qq_42951560/article/details/114924366