【PyTorch】教程:torch.nn.Hardtanh

torch.nn.Hardtanh

原型

CLASS torch.nn.Hardtanh(min_val=- 1.0, max_val=1.0, inplace=False, min_value=None, max_value=None)

参数

  • min_val ([float]) – 线性区域的最小值,默认为 -1
  • max_val ([float]) – 线性区域的最大值,默认为 1
  • inplace ([bool]) – 默认为 False

定义

HardTanh ( x ) = { max_val  if  x >  max_val  min_val  if  x <  min_val  x  otherwise  \text{HardTanh}(x) = \begin{cases} \text{max\_val} & \text{ if } x > \text{ max\_val } \\ \text{min\_val} & \text{ if } x < \text{ min\_val } \\ x & \text{ otherwise } \\ \end{cases} HardTanh(x)= max_valmin_valx if x> max_val  if x< min_val  otherwise 

在这里插入图片描述

代码

import torch
import torch.nn as nn

m = nn.Hardtanh(-2, 2)
input = torch.randn(2)
output = m(input)
print("input: ", input)    # input:  tensor([2.1926, 0.2211])
print("output: ", output)  # output:  tensor([2.0000, 0.2211])

【参考】

Hardtanh — PyTorch 1.13 documentation

猜你喜欢

转载自blog.csdn.net/zhoujinwang/article/details/129291032