PyTorch implementa regresión lineal simple

1. Pasos de implementación

1. Preparar datos

x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])

2. Modelo de diseño

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel,self).__init__()
        self.linear = torch.nn.Linear(1,1)
        
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred
        
model = LinearModel()  

3. Construya la función de pérdida y el optimizador

criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

4. Proceso de formación

epoch_list = []
loss_list = []
w_list = []
b_list = []
for epoch in range(1000):
    y_pred = model(x_data)					  # 计算预测值
    loss = criterion(y_pred, y_data)	# 计算损失
    print(epoch,loss)
    
    epoch_list.append(epoch)
    loss_list.append(loss.data.item())
    w_list.append(model.linear.weight.item())
    b_list.append(model.linear.bias.item())
    
    optimizer.zero_grad()   # 梯度归零
    loss.backward()         # 反向传播
    optimizer.step()        # 更新

5. Visualización de resultados

Muestre los pesos y sesgos finales:

# 输出权重和偏置
print('w = ',model.linear.weight.item())
print('b = ',model.linear.bias.item())

El resultado es:

w =  1.9998501539230347
b =  0.0003405189490877092

Prueba modelo:

# 测试模型
x_test = torch.tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ',y_test.data)
y_pred =  tensor([[7.9997]])

Trace la curva 2D del valor de pérdida en función del número de iteraciones y el diagrama de dispersión 3D de la pérdida a medida que cambian el peso y el sesgo:

# 二维曲线图
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

# 三维散点图
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(w_list,b_list,loss_list,c='r')
#设置坐标轴
ax.set_xlabel('weight')
ax.set_ylabel('bias')
ax.set_zlabel('loss')
plt.show()

El resultado se muestra a continuación:
inserte la descripción de la imagen aquíinserte la descripción de la imagen aquí

2. Referencias

[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=5

Supongo que te gusta

Origin blog.csdn.net/weixin_43821559/article/details/123298468
Recomendado
Clasificación