Implementing linear regression with PyTorch

Concept

Linear regression models the relationship between one variable and one or more other variables.

  • Dependent variable: y
  • Independent variable: x
  • Relationship: linear
  • Expression: y = wx + b
  • Goal: solve for w and b
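
With a single feature, the w and b that gradient descent searches for can also be computed directly with the ordinary least-squares formulas w = cov(x, y) / var(x) and b = mean(y) - w * mean(x). The sketch below is not part of the original post. It reuses the post's synthetic data recipe (true w = 2, true b = 5, unit Gaussian noise) but assumes 200 samples instead of 20, and it gives a reference answer to compare trained parameters against.

```python
import torch

torch.manual_seed(10)

# Synthetic data following the same recipe as the training script: y = 2x + 5 + noise
# (200 samples is an assumption made here to get a tighter estimate)
x = torch.rand(200, 1) * 10
y = 2 * x + (5 + torch.randn(200, 1))

# Ordinary least squares for one feature:
#   w = sum((x - x_mean) * (y - y_mean)) / sum((x - x_mean)^2)
#   b = y_mean - w * x_mean
x_mean, y_mean = x.mean(), y.mean()
w_ols = ((x - x_mean) * (y - y_mean)).sum() / ((x - x_mean) ** 2).sum()
b_ols = y_mean - w_ols * x_mean

print(f"w = {w_ols.item():.3f}, b = {b_ols.item():.3f}")
```

The estimates land close to, but not exactly on, 2 and 5 because of the added noise.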

Solving steps:

  1. Choose the model
    Model: y = wx + b
  2. Choose a loss function
    Mean squared error (MSE): \frac{1}{m}\sum_{i=1}^{m}(y_i - \hat{y}_i)^2
    (the code multiplies it by 0.5, which does not change where the minimum is)
  3. Compute the gradients and update w and b
    Gradient descent:
    w = w - lr * w.grad
    b = b - lr * b.grad
    lr is the step size of each update, i.e. the learning rate
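
For the half-MSE loss used in the training code, L = (1/(2m)) * sum((y_i - y_hat_i)^2), the 0.5 cancels the 2 from differentiating the square, so dL/dw = -(1/m) * sum((y_i - y_hat_i) * x_i) and dL/db = -(1/m) * sum(y_i - y_hat_i). The sketch below is not from the original post: it computes these gradients by hand for a small batch, checks them against what `loss.backward()` produces, and then applies one gradient-descent step. The batch size of 5, the seed and lr = 0.01 are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

# Small illustrative batch (size and seed are arbitrary)
x = torch.rand(5, 1) * 10
y = 2 * x + 5 + torch.randn(5, 1)
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
lr = 0.01

# Forward pass and half-MSE loss, then let autograd fill w.grad and b.grad
y_hat = w * x + b
loss = (0.5 * (y - y_hat) ** 2).mean()
loss.backward()

# Hand-derived gradients, computed outside the autograd graph
with torch.no_grad():
    residual = y - (w * x + b)
    grad_w_manual = -(residual * x).mean()
    grad_b_manual = -residual.mean()

# Compare hand-derived gradients with autograd's
print(torch.allclose(w.grad, grad_w_manual.reshape(1), atol=1e-5))
print(torch.allclose(b.grad, grad_b_manual.reshape(1), atol=1e-5))

# One gradient-descent step: parameter <- parameter - lr * gradient
with torch.no_grad():
    w_before = w.clone()  # snapshot of w, kept only to check the update
    w -= lr * w.grad
    b -= lr * b.grad
```

The in-place update is wrapped in `torch.no_grad()` so that autograd does not record it as part of the graph.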
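
import torch
import matplotlib.pyplot as plt

torch.manual_seed(10)  # fix the random seed so the results are reproducible
lr = 0.01  # learning rate; larger values such as 0.1 can make the updates diverge for x in [0, 10]
# Create training data: y = 2x + 5 plus noise
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1))  # torch.randn(20, 1) adds Gaussian noise
# Initialize w and b
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# Training loop
for i in range(1000):
    # Forward pass
    wx = torch.mul(w, x)
    y_pre = torch.add(wx, b)  # predictions
    # Compute the loss
    loss = (0.5 * (y - y_pre) ** 2).mean()  # 0.5 cancels the 2 from differentiating the square; mean() averages over samples
    # Backward pass
    loss.backward()  # autograd computes w.grad and b.grad
    # Update the parameters without recording the update in the graph
    with torch.no_grad():
        b -= lr * b.grad
        w -= lr * w.grad
    # Gradients accumulate across backward() calls, so reset them every iteration
    w.grad.zero_()
    b.grad.zero_()
    # Plot once the loss is small enough, then stop
    if loss.item() < 1:
        plt.scatter(x.numpy(), y.numpy())
        plt.plot(x.numpy(), y_pre.detach().numpy(), "r-", lw=5)
        plt.text(2, 10, "loss=%.4f" % loss.item(), fontdict={'size': 20, 'color': 'red'})
        plt.xlim(1.5, 10)
        plt.ylim(8, 28)
        plt.title("i:{}  w:{}  b:{}".format(i, w.detach().numpy(), b.detach().numpy()))
        plt.pause(0.5)
        break
plt.show()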
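
The loop above updates w and b by hand. PyTorch's `torch.nn.Linear` and `torch.optim.SGD` package the same model and update rule; the sketch below is an alternative written for illustration, not the author's code. It uses `torch.nn.MSELoss` (plain mean squared error without the 0.5 factor, so the gradients are twice as large), and the settings lr = 0.01, 2000 iterations and 100 samples are assumptions chosen so the fit converges, not tuned values.

```python
import torch

torch.manual_seed(10)

# Same data recipe as above, with 100 samples (an assumption): y = 2x + 5 + noise
x = torch.rand(100, 1) * 10
y = 2 * x + (5 + torch.randn(100, 1))

model = torch.nn.Linear(in_features=1, out_features=1)  # weight plays the role of w, bias of b
loss_fn = torch.nn.MSELoss()                             # mean of (y - y_pred)^2
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for i in range(2000):
    optimizer.zero_grad()          # clear gradients left over from the previous iteration
    loss = loss_fn(model(x), y)    # forward pass and loss
    loss.backward()                # compute d(loss)/dw and d(loss)/db
    optimizer.step()               # w <- w - lr * grad_w, b <- b - lr * grad_b

w = model.weight.item()
b = model.bias.item()
print(f"iterations: {i + 1}, loss: {loss.item():.4f}, w: {w:.3f}, b: {b:.3f}")
```

`optimizer.zero_grad()` does the same job as the manual `w.grad.zero_()` / `b.grad.zero_()` calls in the hand-written loop.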

Source: blog.csdn.net/SakuraHimi/article/details/104579832