线性与非线性【在nn.Linear() 后使用激活函数就变为了非线性】

起源:代码的类名为“非线性”,我看了一下,就是nn.Linear() 与激活函数的叠加,我寻思nn.Linear() 不是线性层吗?为什么把这个类叫做“非线性”? 

class NonLinear(nn.Module):
    """
    该模块实现了一个基于全连接层的非线性函数映射操作,可用于将输入张量映射到高维空间中,并提取出不同的特征。
    """
    def __init__(self, input, output_size, hidden=None):
        super(NonLinear, self).__init__()

        if hidden is None:
            hidden = input
        self.layer1 = Linear(input, hidden, init="relu")
        self.layer2 = Linear(hidden, output_size, init="final")

    def forward(self, x):
        x = self.layer1(x)
        x = F.gelu(x)
        x = self.layer2(x)
        return x

    # 在 zero_init 方法中,将 layer2 层的权重和偏置初始化为 0,可以防止梯度爆炸的问题
    def zero_init(self):
        nn.init.zeros_(self.layer2.weight)
        nn.init.zeros_(self.layer2.bias)

在神经网络中,一般通过激活函数将线性变换的结果进行非线性化,进而增加模型的表达能力。如果没有使用激活函数,那么多个线性层堆叠起来也仍然是一个线性变换。

因为 nn.Linear() 实质上是一个线性变换操作,只有激活函数的添加才能使得输出非线性化。总之,使用 nn.Linear() 配合激活函数可以构建非线性深度神经网络,从而拟合更加复杂的数据分布和函数关系,提高分类和预测的准确性。

猜你喜欢

转载自blog.csdn.net/weixin_43135178/article/details/130064244