torch.nn.Threshold
Prototype
CLASS torch.nn.Threshold(threshold, value, inplace=False)
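Parameters
- threshold (float) – the threshold each element is compared against
- value (float) – the value that replaces every element not greater than threshold
- inplace (bool) – whether to perform the operation in place; defaults to False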
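To make the inplace flag concrete, here is a minimal sketch using a small hand-picked tensor (the values are illustrative, not from the original). With inplace=True the result is written into the input tensor's own storage instead of a new tensor.

```python
import torch
import torch.nn as nn

x = torch.tensor([-1.0, 0.05, 0.5])

# inplace=True: elements <= 0.1 in x are overwritten with 2.0 directly
m = nn.Threshold(0.1, 2.0, inplace=True)
y = m(x)

print(x)                            # x itself has been modified
print(y.data_ptr() == x.data_ptr())  # output shares x's storage
```

In-place mode saves an allocation, but it destroys the original input, so it should be avoided when that input is still needed later (for example by autograd).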
Definition
y = \begin{cases} x, & \text{if } x > \text{threshold} \\ \text{value}, & \text{otherwise} \end{cases}
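The piecewise formula can be read element by element. The following plain-Python sketch is a hypothetical reference implementation (threshold_ref is an illustrative name, not part of PyTorch); it mainly highlights that the comparison is strict, so an element exactly equal to threshold is replaced by value.

```python
def threshold_ref(xs, threshold, value):
    """Hypothetical reference implementation of the Threshold formula.

    Keeps each element strictly greater than `threshold`; replaces every
    other element (including ones equal to `threshold`) with `value`.
    """
    return [x if x > threshold else value for x in xs]


# Same numbers as the example output in the Code section
xs = [-0.5409, -0.2444, 0.2652, 0.6499]
print(threshold_ref(xs, 0.1, 2.0))  # [2.0, 2.0, 0.2652, 0.6499]

# The comparison is strict: an element equal to the threshold is replaced
print(threshold_ref([0.1], 0.1, 2.0))  # [2.0]
```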
Code
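import torch
import torch.nn as nn

# Elements <= 0.1 are replaced with 2; elements > 0.1 pass through unchanged
m = nn.Threshold(0.1, 2)
input = torch.randn(4)  # random input, so the values differ on every run
output = m(input)
print("input: ", input)
print("output: ", output)
# Example output:
# input: tensor([-0.5409, -0.2444, 0.2652, 0.6499])
# output: tensor([2.0000, 2.0000, 0.2652, 0.6499])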
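As a cross-check, the module's behavior can be written out with torch.where, and with threshold=0 and value=0 it reduces to ReLU. This is a small sketch under those assumptions (the seed and tensor size are arbitrary choices), not a description of how PyTorch implements Threshold internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8)

# nn.Threshold spelled out with torch.where: keep x where x > 0.1, else use 2.0
out_module = nn.Threshold(0.1, 2.0)(x)
out_where = torch.where(x > 0.1, x, torch.full_like(x, 2.0))
print(torch.equal(out_module, out_where))  # True

# With threshold=0 and value=0, Threshold behaves like ReLU
print(torch.equal(nn.Threshold(0.0, 0.0)(x), nn.ReLU()(x)))  # True
```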