Pytorch - マスクドフィルメソッドパラメータの詳細な説明と使用法

1 torch.Tensor.masked_fill パラメータの詳細な説明と使用法

1.1 torch.Tensor.masked_fillパラメータの詳細説明

1. 機能的な形態

torch.Tensor.masked_fill(mask, value)

2. 関数関数
入力マスクマスクma sk は、現在のベース テンソルの形状と一致している必要があります。マスクマスク
しますma s kの True 要素に対応する基本 Tensor の要素がvalue value _ _ _ _

3. 関数パラメータ

  • マスク: マスクは int 型 Tensor (値は 0 または 1) または bool 型 Tensor (値は False または True) のいずれかです。
  • : 浮動小数点、塗りつぶされた値

4. 関数の戻り値
塗りつぶされた Tensor を返します。

1.2 torch.Tensor.masked_fill の使用例

以下は、masked_fill 関数の使用法を示す簡単な例です。まず、4x4 の基本行列を作成し、次に 4x4 の対角行列を作成し、対角線上の基本マシン行列のすべての値を 100 ベースに設定します。対角行列の場合、具体的なコードは次のとおりです。

import torch

if __name__ == '__main__':
    tensor = torch.arange(0,16).view(4,4)
    print('origin tensor:\n{}\n'.format(tensor))

    mask = torch.eye(4,dtype=torch.bool)
    print('mask tensor:\n{}\n'.format(mask))

    tensor = tensor.masked_fill(mask,100)
    print('filled tensor:\n{}'.format(tensor))

出力

origin tensor:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

mask tensor:
tensor([[ True, False, False, False],
        [False,  True, False, False],
        [False, False,  True, False],
        [False, False, False,  True]])

filled tensor:
tensor([[100,   1,   2,   3],
        [  4, 100,   6,   7],
        [  8,   9, 100,  11],
        [ 12,  13,  14, 100]])

おすすめ

転載: blog.csdn.net/HW140701/article/details/126341289
おすすめ