[PyTorch] Tutorial: torch.nn.Tanhshrink

torch.nn.Tanhshrink

Prototype

CLASS torch.nn.Tanhshrink()
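The constructor takes no arguments. The following is a minimal sketch that checks what this implies in practice. The module holds no learnable parameters, it works element-wise on a tensor of any shape, and it matches the functional form `torch.nn.functional.tanhshrink`. The tensor shape used here is an arbitrary choice for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Tanhshrink has no constructor arguments and no learnable parameters.
m = nn.Tanhshrink()
print(len(list(m.parameters())))  # 0

# It is applied element-wise, so the input shape is preserved.
x = torch.randn(2, 3, 4)  # arbitrary example shape
y = m(x)
print(y.shape)  # torch.Size([2, 3, 4])

# The functional form computes the same result without building a module.
print(torch.allclose(y, F.tanhshrink(x)))  # True
```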

Definition

\text{Tanhshrink}(x) = x - \tanh(x)
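As a sanity check on the formula, the sketch below compares the manual expression `x - torch.tanh(x)` with the module's output. It also uses autograd to confirm the derivative, which follows from the definition as d/dx [x - tanh(x)] = 1 - (1 - tanh²(x)) = tanh²(x). The sample points from `torch.linspace` are an arbitrary illustrative choice.

```python
import torch
import torch.nn as nn

# Arbitrary sample points; requires_grad so autograd can differentiate.
x = torch.linspace(-3.0, 3.0, steps=7, requires_grad=True)

# Direct translation of the formula: Tanhshrink(x) = x - tanh(x)
manual = x - torch.tanh(x)
builtin = nn.Tanhshrink()(x)
print(torch.allclose(manual, builtin))  # True

# Gradient of the module output: 1 - (1 - tanh(x)^2) = tanh(x)^2
builtin.sum().backward()
print(torch.allclose(x.grad, torch.tanh(x).detach() ** 2))  # True
```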


Code

import torch
import torch.nn as nn

m = nn.Tanhshrink()
x = torch.randn(4)  # random input, so values differ from run to run
output = m(x)       # element-wise x - tanh(x)

print("input: ", x)
print("output: ", output)
# Sample output from one run:
# input:  tensor([-0.3480,  1.0966,  0.2350,  0.1310])
# output:  tensor([-0.0134,  0.2973,  0.0042,  0.0007])
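In the sample output above, small inputs such as 0.1310 are pushed almost to zero, while larger ones such as 1.0966 keep more of their magnitude. The following sketch makes this reproducible by replacing `torch.randn` with a fixed input tensor of my own choosing. It checks two properties that follow from the formula. First, Tanhshrink is an odd function. Second, near zero it behaves like x³/3 (from the Taylor series tanh(x) = x - x³/3 + ...), and for large |x| it approaches x - sign(x).

```python
import torch
import torch.nn as nn

m = nn.Tanhshrink()
# Fixed inputs (illustrative choice) instead of torch.randn, for reproducibility.
x = torch.tensor([-4.0, -1.0, -0.1, 0.0, 0.1, 1.0, 4.0])
y = m(x)

# Odd function: Tanhshrink(-x) == -Tanhshrink(x)
print(torch.allclose(m(-x), -y))  # True

# Zero maps to zero; small inputs shrink to roughly x**3 / 3,
# while large inputs approach x - sign(x).
for xi, yi in zip(x.tolist(), y.tolist()):
    print(f"{xi:5.1f} -> {yi:.4f}")
```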

[References]

Tanhshrink — PyTorch 1.13 documentation


Reposted from blog.csdn.net/zhoujinwang/article/details/129642147