【PyTorch】教程:torch.nn.Softshrink

torch.nn.Softshrink

原型

CLASS torch.nn.Softshrink(lambd=0.5)

参数

  • lambd (float) – λ \lambda λ 为 Softshrink参数,默认为 0.5, 必须不小于0

定义

SoftShrinkage ( x ) = { x − λ , if  x > λ x + λ , if  x < − λ 0 , otherwise \text{SoftShrinkage}(x)=\begin{cases} x-\lambda, & \text{if } x > \lambda \\ x+\lambda, & \text{if } x < -\lambda \\ 0, & \text{otherwise} \end{cases} SoftShrinkage(x)= xλ,x+λ,0,if x>λif x<λotherwise

在这里插入图片描述

代码

import torch
import torch.nn as nn

m = nn.Softshrink()
input = torch.randn(4)
output = m(input)

print("input: ", input)
print("output: ", output)

# input:  tensor([ 0.9876, -2.0183, -0.7573, -1.7960])
# output:  tensor([ 0.4876, -1.5183, -0.2573, -1.2960])

【参考】

Softshrink — PyTorch 1.13 documentation

猜你喜欢

转载自blog.csdn.net/zhoujinwang/article/details/129641865