参考:pytorch教程之nn.Module类详解——使用Module类来自定义模型
一、简单的模型模板
1、定义网络结构
class MyNet(nn.Module):
# 初始化函数 __init__(self):
# 定义了具体网络有那些层,但并没有决定网络的结构。
def __init__(self) -> None:
super().__init__()
# 前向传播 forward():
# 函数定义了网络的的顺序
def forward(self, input):
2、实验一下
前向网络中对传入的值加1
import torch.nn as nn
import torch
class MyNet(nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, input):
# input输入,output输出
output = input + 1
return output
x = torch.tensor(1.0)
# 初始化网络
MyNet = MyNet()
output = MyNet(x)
print(output)
输出:
二、自定义网络模型
class MyNet(nn.Module):
def __init__(self) -> None:
super(MyNet, self).__init__()
self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
# 卷积层中stride默认为1,池化层中stride默认为kernel_size的大小
def forward(self, x):
x = self.conv1(x)
return x