concept
Linear regression analysis is the relationship between a variable and a further (poly) variables
- Dependent variable: y
- Arguments: x
- Relationship: Linear
- The expression: y = wx + b
- Objective: Solution b and w
Solving steps:
- Determining Model
Model: y = wx + b - Select loss function
mean square error MSE: - Solving the gradient and updating w, b
gradient descent:
W = W - w.grad the LR *
B = B - * w.grad the LR
the LR increments, learning rate
import torch
import matplotlib.pyplot as plt
torch.manual_seed(10) # 初始化随机数种子,保证结果可以复现
lr = 0.1 # 学习率
# 创建训练数据
x = torch.rand(20, 1) * 10
y = 2 * x + (5 + torch.randn(20, 1)) # torch.randn(20, 1)加入噪声
# 初始化w和b
w = torch.randn(1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
# 开始迭代
for i in range(1000):
# 前向传播
wx = torch.mul(w, x)
y_pre = torch.add(wx, b) # 预测值
# 计算损失
loss = (0.5 * (y - y_pre) ** 2).mean() # 乘以0.5是为了求导过程中消除平方2的影响,mean()求均值
# 反向传播
loss.backward() # 自动求导,得到梯度
# 更新参数
b.data.sub_(lr * b.grad)
w.data.sub_(lr * w.grad)
# 绘图
if loss.data.numpy() < 1:
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), y_pre.data.numpy(), "r-", lw=5)
plt.text(2, 10, "loss=%.4f" % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})
plt.xlim(1.5, 10)
plt.ylim(8, 28)
plt.title("i:{} w:{} b:{}".format(i, w.data.numpy(), b.data.numpy()))
plt.pause(0.5)
break