nn.linear()函数

import torch
import torch.nn as nn
import torch.nn.functional as F

class LinearFC(nn.Module):

    def __init__(self):
        super(DropoutFC, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, input):
        out = self.fc(input)
        return out

Net = LinearFC()
x = torch.randint(10, (2, 3)).float()  # 随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据
Net.train()
output = Net(x)
print(output)

# train the Net

创建了一个最简单的LinearFC模型,里面有一个线性函数nn.Linear(3, 2),线性变换公式为: y = x W T + b y=x W^T + b y=xWT+b

通过Debug,一步一步查看运行情况:

在这里插入图片描述

当前这一步可以看到模型给我们随机初始化了权重 W 2 × 3 W_{2 \times 3} W2×3和偏置 b 2 × 3 b_{2 \times 3} b2×3,为什么权重 W W W的shape是 2 × 3 2\times3 2×3,因为公式里需要转置。

x x x随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据。
在这里插入图片描述
可以看出使用模型算出来的output,与手动使用公式算出来的结果一致。
在这里插入图片描述

Net.train()的作用

当网络中有 dropout,Batch Normalization 的时候。训练的要记得 Net.train(), 测试 要记得 Net.eval()。

在训练模型时会在前面加上:

Net.train()

在测试模型时在前面使用:

model.eval()

同时发现,如果不写这两个程序也可以运行,这是因为这两个方法是针对在网络训练和测试时采用不同方式的情况,比如Batch Normalization 和 Dropout。

猜你喜欢

转载自blog.csdn.net/vincent_duan/article/details/119934349