1 torch.Tensor.masked_fill パラメータの詳細な説明と使用法
1.1 torch.Tensor.masked_fillパラメータの詳細説明
1. 機能的な形態
torch.Tensor.masked_fill(mask, value)
2. 関数関数
入力マスクマスクma sk は、現在のベース テンソルの形状と一致している必要があります。マスクマスク
しますma s kの True 要素に対応する基本 Tensor の要素がvalue value値。 _ _ _ _
3. 関数パラメータ
- マスク: マスクは int 型 Tensor (値は 0 または 1) または bool 型 Tensor (値は False または True) のいずれかです。
- 値: 浮動小数点、塗りつぶされた値
4. 関数の戻り値
塗りつぶされた Tensor を返します。
1.2 torch.Tensor.masked_fill の使用例
以下は、masked_fill 関数の使用法を示す簡単な例です。まず、4x4 の基本行列を作成し、次に 4x4 の対角行列を作成し、対角線上の基本マシン行列のすべての値を 100 ベースに設定します。対角行列の場合、具体的なコードは次のとおりです。
import torch
if __name__ == '__main__':
tensor = torch.arange(0,16).view(4,4)
print('origin tensor:\n{}\n'.format(tensor))
mask = torch.eye(4,dtype=torch.bool)
print('mask tensor:\n{}\n'.format(mask))
tensor = tensor.masked_fill(mask,100)
print('filled tensor:\n{}'.format(tensor))
出力
origin tensor:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
mask tensor:
tensor([[ True, False, False, False],
[False, True, False, False],
[False, False, True, False],
[False, False, False, True]])
filled tensor:
tensor([[100, 1, 2, 3],
[ 4, 100, 6, 7],
[ 8, 9, 100, 11],
[ 12, 13, 14, 100]])