【PyTorch】Tutorial: torch.nn.Threshold

torch.nn.Threshold

Signature

CLASS torch.nn.Threshold(threshold, value, inplace=False)
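The same operation is also available as a plain function. The sketch below assumes `torch.nn.functional.threshold(input, threshold, value, inplace=False)`. It checks that the module form and the functional form agree on a small hand-picked input (the values are arbitrary, chosen for illustration).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.tensor([-1.0, 0.0, 0.5, 3.0])  # arbitrary example input

# Module form: construct once with threshold/value, then call it
module_out = nn.Threshold(0.1, 2.0)(x)

# Functional form: pass threshold and value on every call
functional_out = F.threshold(x, 0.1, 2.0)

print(module_out.tolist())                      # [2.0, 2.0, 0.5, 3.0]
print(torch.equal(module_out, functional_out))  # True
```

The module form suits `nn.Sequential` pipelines. The functional form is handy inside a custom `forward` when the threshold is not a fixed layer hyperparameter.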

Parameters

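  • threshold (float) – the threshold each element is compared against
  • value (float) – the value that replaces every element not greater than threshold
  • inplace (bool) – whether to perform the operation in place; default: False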
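The sketch below shows what `inplace` changes, using an arbitrary example input. With the default `inplace=False`, the input tensor is left untouched and a new tensor is returned. With `inplace=True`, the input tensor itself is overwritten, and the returned tensor shares its storage.

```python
import torch
import torch.nn as nn

x = torch.tensor([-1.0, 0.0, 0.5, 3.0])  # arbitrary example input

# inplace=False (default): x is unchanged, y is a new tensor
y = nn.Threshold(0.1, 2.0)(x)
print(x.tolist())  # [-1.0, 0.0, 0.5, 3.0]
print(y.tolist())  # [2.0, 2.0, 0.5, 3.0]

# inplace=True: x is modified directly and returned
z = nn.Threshold(0.1, 2.0, inplace=True)(x)
print(x.tolist())                    # [2.0, 2.0, 0.5, 3.0]
print(z.data_ptr() == x.data_ptr())  # True
```

In-place operation saves a memory allocation. Only use it when nothing else still needs the original input values.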

Definition

$$
y = \begin{cases} x, & \text{if } x > \text{threshold} \\ \text{value}, & \text{otherwise} \end{cases}
$$
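The piecewise definition maps directly onto an element-wise select. The sketch below defines `my_threshold`, a hypothetical helper built on `torch.where`. It exists only to check the formula and is not how PyTorch implements the module. It is compared against `nn.Threshold` on finite inputs only. The sketch also checks a special case: with `threshold=0` and `value=0`, every positive input is kept and everything else becomes 0, which is exactly ReLU.

```python
import torch
import torch.nn as nn

def my_threshold(x: torch.Tensor, threshold: float, value: float) -> torch.Tensor:
    # Hypothetical helper: keep x where x > threshold, otherwise use value
    return torch.where(x > threshold, x, torch.full_like(x, value))

x = torch.tensor([-2.0, -0.5, 0.0, 0.25, 0.5, 4.0])  # finite example input

# The formula-based helper matches the module on these inputs
print(torch.equal(my_threshold(x, 0.5, 2.0), nn.Threshold(0.5, 2.0)(x)))  # True

# Special case: threshold=0, value=0 reproduces ReLU
print(torch.equal(nn.Threshold(0.0, 0.0)(x), torch.relu(x)))  # True
```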

Code

import torch
import torch.nn as nn

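# Elements <= 0.1 are replaced by 2; elements > 0.1 pass through unchanged
m = nn.Threshold(0.1, 2)
input = torch.randn(4)  # random input, so values differ between runs
output = m(input)

print("input: ", input)
print("output: ", output)

# Example output from one run:
# input:  tensor([-0.5409, -0.2444,  0.2652,  0.6499])
# output:  tensor([2.0000, 2.0000, 0.2652, 0.6499])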
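Because `torch.randn` draws a new random input on every run, the values printed above change each time. The deterministic variant below uses a hand-picked input, chosen for illustration. It sets the threshold to 0.5, which is exactly representable in float32, so the boundary case is unambiguous. An element exactly equal to `threshold` is also replaced, because the comparison in the definition is strict (`x > threshold`).

```python
import torch
import torch.nn as nn

m = nn.Threshold(0.5, 2)                  # 0.5 is exact in float32
x = torch.tensor([-1.0, 0.5, 0.75, 3.0])  # fixed input instead of torch.randn
y = m(x)

# -1.0 and 0.5 (equal to the threshold) become 2.0; 0.75 and 3.0 pass through
print(y.tolist())  # [2.0, 2.0, 0.75, 3.0]
```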



Reposted from blog.csdn.net/zhoujinwang/article/details/129642472