起源:代码的类名为“非线性”,我看了一下,就是nn.Linear() 与激活函数的叠加,我寻思nn.Linear() 不是线性层吗?为什么把这个类叫做“非线性”?
class NonLinear(nn.Module):
"""
该模块实现了一个基于全连接层的非线性函数映射操作,可用于将输入张量映射到高维空间中,并提取出不同的特征。
"""
def __init__(self, input, output_size, hidden=None):
super(NonLinear, self).__init__()
if hidden is None:
hidden = input
self.layer1 = Linear(input, hidden, init="relu")
self.layer2 = Linear(hidden, output_size, init="final")
def forward(self, x):
x = self.layer1(x)
x = F.gelu(x)
x = self.layer2(x)
return x
# 在 zero_init 方法中,将 layer2 层的权重和偏置初始化为 0,可以防止梯度爆炸的问题
def zero_init(self):
nn.init.zeros_(self.layer2.weight)
nn.init.zeros_(self.layer2.bias)
在神经网络中,一般通过激活函数将线性变换的结果进行非线性化,进而增加模型的表达能力。如果没有使用激活函数,那么多个线性层堆叠起来也仍然是一个线性变换。
因为 nn.Linear() 实质上是一个线性变换操作,只有激活函数的添加才能使得输出非线性化。总之,使用 nn.Linear() 配合激活函数可以构建非线性深度神经网络,从而拟合更加复杂的数据分布和函数关系,提高分类和预测的准确性。