pytorch系列: nn.Modlue及nn.Linear 源码理解

先看一个列子:

import torch
from torch import nn

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)

output.size()

out:

torch.Size([128, 30])

刚开始看这份代码是有点迷惑的,m是类对象,而直接像函数一样调用m,m(input)

重点:

     nn.Module 是所有神经网络单元(neural network modules)的基类
     pytorch在nn.Module中,实现了__call__方法,而在__call__方法中调用了forward函数。

经过以上两点。上述代码就不难理解。

接下来看一下源码:
https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/module.py

再来看一下nn.Linear
https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html
主要看一下forward函数:


返回的是:
input∗weight+bias input * weight + bias
input∗weight+bias
的线性函数

此时再看一下这一份代码:

import torch
from torch import nn

m = nn.Linear(20, 30)
input = torch.randn(128, 20)
output = m(input)

output.size()

# define three layers
class simpleNet(nn.Module):

    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer1 = nn.Linear(in_dim, n_hidden_1)
        self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
        self.layer3 = nn.Linear(n_hidden_2, out_dim)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        return x

以下为各层神经元个数:
输入: in_dim
第一层: n_hidden_1
第二层:n_hidden_2
第三层(输出层):out_dim

转自:https://blog.csdn.net/dss_dssssd/article/details/82977170

猜你喜欢

转载自blog.csdn.net/qq_36652619/article/details/86608788
今日推荐