Pytorch 学习(五):Pytorch 实现多层感知机(MLP)

Pytorch 实现多层感知机(MLP)

本方法总结自《动手学深度学习》(PyTorch 版)GitHub 项目。

实现多层感知机(Multilayer Perceptron, MLP)同样遵循以下步骤:

  • 数据集读取
  • 模型搭建和参数初始化
  • 损失函数和优化器构建
  • 模型训练

方法一:从零开始实现

import torch
import torch.nn as nn
import numpy as np
import d2lzh_pytorch as d2l

# Layer widths: flattened 28*28 input, 256 hidden units, 10 outputs.
num_i, num_h, num_o = 28 * 28, 256, 10

# Mini-batch iterators over Fashion-MNIST, provided by d2lzh_pytorch.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)


def _normal_matrix(rows, cols):
    """Return a (rows, cols) float32 leaf tensor sampled from N(0, 0.01^2).

    The sample is drawn with NumPy and copied into a tensor created with
    ``requires_grad=True`` so autograd tracks gradients for it.
    """
    sample = np.random.normal(0, 0.01, (rows, cols))
    return torch.tensor(sample, dtype=torch.float32, requires_grad=True)


# Trainable parameters, initialised by hand: Gaussian weights, zero biases.
# Weights are drawn in the order w1 then w2, so the NumPy RNG stream is consumed
# exactly as in a straight-line version.
w1 = _normal_matrix(num_i, num_h)
b1 = torch.zeros(num_h, requires_grad=True)
w2 = _normal_matrix(num_h, num_o)
b2 = torch.zeros(num_o, requires_grad=True)
params = [w1, b1, w2, b2]

# 激活函数
def relu(x):
    return torch.max(x, torch.tensor(0.0))

# Model: forward pass of the hand-built MLP using the module-level w1, b1, w2, b2.
def net(x):
    """Map a batch of inputs to raw output scores (one per output unit).

    ``x`` is reshaped to ``(-1, num_i)``, passed through the hidden layer
    followed by ReLU, then through the output layer. No softmax is applied
    here; the returned values are unnormalised.
    """
    flat = x.view(-1, num_i)
    hidden = relu(flat.mm(w1) + b1)
    return hidden.mm(w2) + b2

# Loss: nn.CrossEntropyLoss is applied to the raw scores returned by net
# (net itself applies no softmax).
loss = nn.CrossEntropyLoss()

# Training: run the d2l training loop, handing it the hand-made parameter
# list and learning rate (presumably for its own manual SGD step).
# NOTE(review): lr = 100.0 is far above usual SGD rates; presumably the d2l
# update rescales the gradient (e.g. divides by batch_size) — confirm against
# d2lzh_pytorch before changing it.
num_epochs, lr = 5, 100.0
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, params, lr)

方法二:调用 PyTorch 内置模块实现(能调包就不手写)

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
import d2lzh_pytorch as d2l

# Layer widths of the MLP: flattened 28*28 input -> 256 hidden units -> 10 outputs.
num_i = 28 * 28
num_h = 256
num_o = 10

# Fashion-MNIST mini-batch iterators from the d2lzh_pytorch helpers.
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# Network definition
class MLP(nn.Module):
    """Multilayer perceptron: flatten -> Linear -> ReLU -> Linear.

    Args:
        n_i: number of input features after flattening each sample.
        n_h: width of the hidden layer.
        n_o: number of outputs; the forward pass returns raw (unnormalised)
            scores with no softmax applied.
    """

    def __init__(self, n_i, n_h, n_o):
        super(MLP, self).__init__()
        # nn.Flatten() (start_dim=1) keeps the batch dimension and flattens the
        # rest. It presumably matches d2l.FlattenLayer (x.view(batch, -1)) —
        # confirm — but ships with PyTorch, so the model no longer depends on
        # d2lzh_pytorch. It has no parameters, so state_dict keys are unchanged.
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(n_i, n_h)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(n_h, n_o)

    def forward(self, input):
        """Return raw output scores of shape (batch, n_o) for a batch of inputs."""
        return self.linear2(self.relu(self.linear1(self.flatten(input))))

# Build the network, then re-draw every parameter (weights and biases alike)
# from a normal distribution with mean 0 and standard deviation 0.01.
net = MLP(num_i, num_h, num_o)
for tensor in net.parameters():
    init.normal_(tensor, mean=0, std=0.01)

# Training configuration: 5 epochs, cross-entropy on the raw scores, and
# plain SGD with learning rate 0.5 over all of net's parameters.
num_epochs = 5
loss = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.5)

# Train with the d2l loop, passing the optimizer by keyword.
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, optimizer=optimizer)

猜你喜欢

转载自blog.csdn.net/qq_40491305/article/details/106756621