Implementing linear regression with PyTorch

The script below fits a one-variable linear model y = w·x + b using nn.Linear, trains it with SGD on an MSE loss, and plots the fitted line against the training data.

import torch
import os
import numpy as np
import matplotlib.pyplot as plt 

from torch import nn, optim

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # workaround for duplicate OpenMP runtimes on some installs
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
learning_rate = 1e-3   # SGD step size
input_size = 1         # one input feature (x)
output_size = 1        # one output value (y)
epoch_num = 1000       # number of full-batch passes over the data

# Training data: 15 (x, y) pairs, each array of shape (15, 1)
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168], 
                    [9.779], [6.182], [7.59], [2.167], [7.042], 
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
 
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573], 
                    [3.366], [2.596], [2.53], [1.221], [2.827], 
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

class LinearRegression(nn.Module):
    def __init__(self, input_size, output_size):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        out = self.linear(x)
        return out
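With input_size = output_size = 1, nn.Linear holds a single weight w and bias b, and the forward pass computes out = x·wᵀ + b. A minimal sketch verifying that by hand (the tensor values are made up for illustration; it reuses the imports above):

layer = nn.Linear(1, 1)
x = torch.tensor([[2.0], [3.0]])           # batch of 2 samples, 1 feature each
manual = x @ layer.weight.T + layer.bias   # the affine map nn.Linear applies
assert torch.allclose(layer(x), manual)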

model = LinearRegression(input_size, output_size).to(device)  # keep model and data on the same device
# print(model)
criterion = nn.MSELoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
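Under the hood, plain optim.SGD (no momentum, no weight decay) updates each parameter as p ← p − lr · ∇p loss. A minimal hand-rolled equivalent, shown only for illustration:

# Illustrative only: the manual update optim.SGD performs on step()
with torch.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p -= learning_rate * p.grad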

# Training

inputs = torch.from_numpy(x_train).to(device)    # convert once; the data never changes
targets = torch.from_numpy(y_train).to(device)

for epoch in range(epoch_num):
    outputs = model(inputs)             # forward pass
    loss = criterion(outputs, targets)  # mean squared error

    optimizer.zero_grad()  # clear gradients accumulated in the previous step
    loss.backward()        # backpropagate
    optimizer.step()       # update the weight and bias

    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epoch_num}], loss: {loss.item():.4f}")
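After the loop finishes, the learned slope and intercept can be read straight from the layer (a quick sketch; the exact numbers depend on the random initialization):

w = model.linear.weight.item()   # learned slope
b = model.linear.bias.item()     # learned intercept
print(f"fitted line: y = {w:.4f} * x + {b:.4f}")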

plt.figure()
predicted = model(torch.from_numpy(x_train).to(device)).detach().cpu().numpy()

plt.plot(x_train, y_train, "ro", label="original data")
plt.plot(x_train, predicted, label="predict")
plt.legend()
# plt.xlabel("x_train")
# plt.ylabel("y_train")
plt.show()
Running the script prints the loss every 10 epochs (excerpted; the intermediate epochs continue the same slow decline):

Epoch [10/1000], loss: 2.1752
Epoch [20/1000], loss: 0.5747
Epoch [30/1000], loss: 0.3116
Epoch [40/1000], loss: 0.2680
Epoch [50/1000], loss: 0.2605
...
Epoch [980/1000], loss: 0.2248
Epoch [990/1000], loss: 0.2246
Epoch [1000/1000], loss: 0.2243

The loss drops sharply over the first ~50 epochs and then decreases only gradually, so with lr = 1e-3 the fit is close to, but not fully at, the least-squares optimum after 1000 epochs.

[Figure: scatter plot of the training data (red dots) with the fitted regression line]
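As a sanity check, the SGD result can be compared with the closed-form least-squares solution, which minimizes the same MSE exactly. A short sketch using np.linalg.lstsq on the training data defined above (variable names are illustrative):

X_aug = np.hstack([x_train, np.ones_like(x_train)])   # append a column of ones for the bias
coef, *_ = np.linalg.lstsq(X_aug, y_train, rcond=None)
w_ls, b_ls = coef.ravel()
print(f"least squares: y = {w_ls:.4f} * x + {b_ls:.4f}")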

Reposted from blog.csdn.net/qq_37369201/article/details/109459835