pytorch实现线性回归模型

线性回归模型从零开始的实现

步骤:

  1. 准备数据集
  2. 定义模型
  3. 初始化模型参数
  4. 定义损失函数
  5. 定义优化函数
  6. 训练模型

y = x1 * w1 + x2 + w2 + b

导包

import torch
from IPython import display
from matplotlib import pyplot as plt
import numpy as np
import random

生成数据集

# 特征数:2
num_inputs = 2
# 样本数:1000
num_examples = 1000

# 设置权重
true_w = [2, -3.4]
# 设置偏置
true_b = 4.2

# 1000个特征数为2的样本
features = torch.randn(num_examples, num_inputs,dtype=torch.float32)
# 计算出对应的标签
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
# 给标签加噪音
labels += torch.tensor(np.random.normal(0, 0.01, size=labels.size()),dtype=torch.float32)

使用图像来展示生成的数据
在这里插入图片描述
读取数据集

import torch.utils.data as Data

batch_size = 10

# combine featues and labels of dataset
dataset = Data.TensorDataset(features, labels)

# put dataset into DataLoader
data_iter = Data.DataLoader(
    dataset=dataset,            #  Data.TensorDataset(特征, 标签)
    batch_size=batch_size,      # 每批的数据量
    shuffle=True,               # 是否打乱数据
    num_workers=2,              # read data in multithreading
)

定义模型

class LinearNet(nn.Module):
    def __init__(self, n_feature):
        super(LinearNet, self).__init__()      # 继承父类的初始化
        self.linear = nn.Linear(n_feature, 1)  # torch.nn.Linear(in_features, out_features, bias=True)

    def forward(self, x):
        y = self.linear(x)
        return y
    
net = LinearNet(num_inputs)

初始化模型参数

from torch.nn import init

init.normal_(net[0].weight, mean=0.0, std=0.01)
init.constant_(net[0].bias, val=0.0)

定义损失函数
均方误差

loss = nn.MSELoss()

定义优化函数
梯度下降

import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.03)

训练模型

num_epochs = 3
for epoch in range(1, num_epochs + 1):
    for X, y in data_iter:
        output = net(X)
        l = loss(output, y.view(-1, 1))
        optimizer.zero_grad() # 梯度归零
        l.backward()		  # 误差反向传播
        optimizer.step()	  # 根据误差,优化参数
    print('epoch %d, loss: %f' % (epoch, l.item()))

结果对比

dense = net[0]
print(true_w, true_b)
print(dense.weight.data, dense.bias.data)

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/wjl__ai__/article/details/108102626