Implementing simple logistic regression in PyTorch

1. Implementation code

This follows the same approach as https://blog.csdn.net/weixin_43821559/article/details/123298468 , except that training runs on the GPU when one is available. The code is as follows:

import torch
import matplotlib.pyplot as plt
import numpy as np
import os
# Suppress the "duplicate OpenMP runtime" error that some environments raise
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# 1. Prepare the data
x_data = torch.Tensor([[1.0],[2.0],[3.0]])
y_data = torch.Tensor([[0],[0],[1]])

# 2. Design the model
class LogisticRegressionModel(torch.nn.Module):
    def __init__(self):
        super(LogisticRegressionModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred
model = LogisticRegressionModel()

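### Move the model and the training data to the GPU
# Move the model to the GPU (falls back to the CPU if CUDA is unavailable)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# Move the data to the same device as the model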
x = x_data.to(device)
y = y_data.to(device)

# 3. Construct the loss function and the optimizer
criterion = torch.nn.BCELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

# 4. Training loop
epoch_list = []
loss_list = []
w_list = []
b_list = []
for epoch in range(1000):
    y_pred = model(x)
    loss = criterion(y_pred, y)
    print(epoch, loss.item())  # print the scalar loss rather than the tensor object

    epoch_list.append(epoch)
    loss_list.append(loss.item())
    w_list.append(model.linear.weight.item())
    b_list.append(model.linear.bias.item())

    optimizer.zero_grad()  # reset the gradients
    loss.backward()  # backpropagation
    optimizer.step()  # update the parameters

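# Print the learned weight and bias
print('w = ',model.linear.weight.item())
print('b = ',model.linear.bias.item())
# Test the model on 200 evenly spaced points between 0 and 10 hours
x = np.linspace(0,10,200)
x_t = torch.Tensor(x).view(-1,1).to(device)
y_t = model(x_t)
# Detach from the graph and copy back to the CPU before converting to NumPy
y = y_t.detach().cpu().numpy()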

plt.plot(x,y,'b')
plt.plot([0,10],[0.5,0.5],c='r')
plt.xlabel('Hours')
plt.ylabel('Probability of Pass')
plt.grid()
plt.xlim([0,10])
plt.show()

# Plot the loss curve
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.show()

# 3D scatter plot of the loss against the weight and the bias
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(w_list,b_list,loss_list,c='r')
# Label the axes
ax.set_xlabel('weight')
ax.set_ylabel('bias')
ax.set_zlabel('loss')
plt.show()

Output:

w =  1.3059688806533813
b =  -3.15130352973938
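The learned parameters can be checked by hand. The sketch below is an addition, not part of the original script: it plugs the reported w and b into the sigmoid and computes where the predicted probability crosses 0.5. The helper name `predict` is an assumption introduced for illustration.

```python
import math

# Values copied from the output reported above.
w = 1.3059688806533813
b = -3.15130352973938

def predict(hours):
    """Probability of passing: sigmoid(w * hours + b)."""
    return 1.0 / (1.0 + math.exp(-(w * hours + b)))

# sigmoid(z) equals 0.5 exactly when z = 0, so the boundary is hours = -b / w.
boundary = -b / w
print(f"decision boundary: {boundary:.3f} hours")

# Probabilities for the three training inputs and one unseen input.
for hours in (1.0, 2.0, 3.0, 4.0):
    print(f"{hours:.1f} h -> P(pass) = {predict(hours):.3f}")
```

With these values the boundary falls at about 2.413 hours, so the samples at 1 and 2 hours are classified as 0 and the sample at 3 hours as 1, which matches y_data.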

[Figures: predicted probability of passing versus hours; training loss versus epoch; 3D scatter of loss against weight and bias]
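To make the training loop less of a black box, here is a rough plain-Python sketch of what `BCELoss(reduction='sum')`, `loss.backward()`, and `optimizer.step()` amount to for this one-feature model. It is an illustration rather than the script's actual code path: the helper names (`sigmoid`, `bce_sum`, `gradients`) are hypothetical, and the parameters start at zero as an assumption, whereas `torch.nn.Linear` starts from random values. The final numbers will therefore not match the output above exactly.

```python
import math

# Same three samples as x_data / y_data in the script.
xs = [1.0, 2.0, 3.0]
ys = [0.0, 0.0, 1.0]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def bce_sum(w, b):
    """Binary cross-entropy summed over all samples."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = sigmoid(w * x + b)
        total -= y * math.log(p) + (1.0 - y) * math.log(1.0 - p)
    return total

def gradients(w, b):
    """Gradient of the summed loss: dL/dw = sum((p - y) * x), dL/db = sum(p - y)."""
    grad_w = grad_b = 0.0
    for x, y in zip(xs, ys):
        err = sigmoid(w * x + b) - y
        grad_w += err * x
        grad_b += err
    return grad_w, grad_b

# Assumption: zero initial values (the PyTorch model starts from random ones).
w, b = 0.0, 0.0
lr = 0.01  # same learning rate as the script
initial_loss = bce_sum(w, b)

for epoch in range(1000):
    grad_w, grad_b = gradients(w, b)
    # Plain gradient-descent update, as SGD without momentum performs.
    w -= lr * grad_w
    b -= lr * grad_b

final_loss = bce_sum(w, b)
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
print(f"w = {w:.4f}, b = {b:.4f}")
```

Because all three samples are fed in every step, this is full-batch gradient descent, the same as the loop in the script. The loss starts at 3·ln 2 (every prediction is 0.5 when w = b = 0) and decreases from there.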

2. References

[1] https://www.bilibili.com/video/BV1Y7411d7Ys?p=6

Origin blog.csdn.net/weixin_43821559/article/details/123299601