torch.nn.Linear

torch.nn.Linear

torch.nn.Linear 类用于定义模型的线性层,即完成前面提到的不同的层之间的线性变换。

torch.nn.Linear 类接收的参数有三个,分别是输入特征数、输出特征数和是否使用偏置,设置是否使用偏置的参数是一个布尔值,默认为 True ,即使用偏置。

在实际使用的过程中,我们只需将输入的特征数和输出的特征数传递给 torch.nn.Linear 类,就会自动生成对应维度的权重参数和偏置,对于生成的权重参数和偏置,我们的模型默认使用了一种比之前的简单随机方式更好的参数初始化方法。

根据我们搭建模型的输入、输出和层次结构需求,它的输入是在一个批次中包含 100 个特征数为 1000 的数据,最后得到 100 个特征数为 10 的输出数据,中间需要经过两次线性变换,所以要使用两个线性层,两个线性层的代码分别是

torch.nn.Linear(input_data,hidden_layer)
torch.nn.Linear(hidden_layer, output_data)
可看到,其代替了之前使用矩阵乘法方式的实现,代码更精炼、简洁。

猜你喜欢

转载自blog.csdn.net/weixin_44039930/article/details/121779061
今日推荐