torch.nn.Tanhshrink
Prototype
CLASS torch.nn.Tanhshrink()
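Tanhshrink is constructed without arguments, holds no learnable parameters, and is applied element-wise, so the output has the same shape as the input. The following minimal sketch checks these properties. The layer sizes in the `nn.Sequential` stack are arbitrary and were chosen only for illustration.

```python
import torch
import torch.nn as nn

# Tanhshrink takes no constructor arguments
act = nn.Tanhshrink()

# It has no learnable parameters
num_params = sum(p.numel() for p in act.parameters())

# Element-wise: any input shape is preserved
x = torch.randn(2, 3, 5)
y = act(x)

# Used like any other activation inside a model (sizes are arbitrary)
model = nn.Sequential(nn.Linear(8, 16), nn.Tanhshrink(), nn.Linear(16, 4))
out = model(torch.randn(10, 8))

print(num_params)  # 0
print(y.shape)     # torch.Size([2, 3, 5])
print(out.shape)   # torch.Size([10, 4])
```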
Definition
Applies the element-wise function:
$\text{Tanhshrink}(x) = x - \tanh(x)$
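As a sanity check on the definition, the module, the functional form `torch.nn.functional.tanhshrink`, and the explicit formula `x - torch.tanh(x)` should agree. The following sketch compares all three on an evenly spaced grid. The grid range is an arbitrary choice.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=9)

# Three ways to compute Tanhshrink
by_module = nn.Tanhshrink()(x)
by_functional = F.tanhshrink(x)
by_formula = x - torch.tanh(x)

print(torch.allclose(by_module, by_formula))      # True
print(torch.allclose(by_functional, by_formula))  # True
```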
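The definition determines the shape of the curve. The derivation below uses the Taylor series of $\tanh$ around 0 and the limits $\tanh(x) \to \pm 1$:

```latex
\begin{aligned}
\text{Tanhshrink}(-x) &= -x - \tanh(-x) = -\text{Tanhshrink}(x) \\
\text{Tanhshrink}(x) &= x - \left(x - \tfrac{x^3}{3} + \tfrac{2x^5}{15} - \cdots\right)
  = \tfrac{x^3}{3} - \tfrac{2x^5}{15} + \cdots, \quad |x| < \tfrac{\pi}{2} \\
\text{Tanhshrink}(x) &\approx x - \operatorname{sign}(x), \quad |x| \gg 1
\end{aligned}
```

The function is therefore odd and close to flat (cubic) near the origin. For large positive or negative inputs, it approaches the lines $y = x - 1$ and $y = x + 1$, respectively. For example, $0.1310^3 / 3 \approx 0.00075$, which is consistent with the `0.0007` in the sample output of the code example.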
Code
import torch
import torch.nn as nn
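m = nn.Tanhshrink()      # create the activation module (no arguments)
input = torch.randn(4)   # 4 samples from a standard normal distribution
output = m(input)        # element-wise: output = input - tanh(input)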
print("input: ", input)
print("output: ", output)
# Sample output (values differ between runs because the input is random):
# input: tensor([-0.3480, 1.0966, 0.2350, 0.1310])
# output: tensor([-0.0134, 0.2973, 0.0042, 0.0007])
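Differentiating the definition gives the local gradient $\frac{d}{dx}\text{Tanhshrink}(x) = 1 - (1 - \tanh^2 x) = \tanh^2 x$, which is never negative and vanishes at $x = 0$. The following minimal sketch checks this against autograd. It uses `float64` only to keep rounding error in the backward pass out of the comparison.

```python
import torch
import torch.nn as nn

# Double precision keeps rounding error in the backward pass negligible
x = torch.randn(6, dtype=torch.float64, requires_grad=True)
y = nn.Tanhshrink()(x)

# Element-wise op, so the gradient of y.sum() gives dy_i/dx_i for each i
y.sum().backward()

# Analytic derivative: d/dx (x - tanh x) = 1 - (1 - tanh^2 x) = tanh^2 x
expected = torch.tanh(x.detach()) ** 2
print(torch.allclose(x.grad, expected))  # True
```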