Pytorch网络定义

Pytorch网络定义方式:可使用nn.Sequential()定义层
import torch
import torch.nn as nn
import torch.utils.data as Data

class MLPmodel(nn.Module):
    def __init__(self):
        super(MLPmodel,self).__init__()
        #定义隐藏层
        self.hidden = nn.Sequential(
            nn.Linear(13,10),
            nn.ReLU(),
            nn.Linear(10,10),
            nn.ReLU(),
        )
        #预测回归层
        self.regression = nn.Linear(10,1)
    #定义前向传播路径
    def forward(self,x):
        x = self.hidden(x)
        output = self.regression(x)
        return output

mlp = MLPmodel()
print(mlp)
运行代码结果如下:

猜你喜欢

转载自blog.csdn.net/qq_48194187/article/details/121474825