IV. PyTorch Deep Learning: Implementing Linear Regression with PyTorch

Lecture 5: Implementing Linear Regression with PyTorch

Source: Bilibili, 刘二大人

Source code:

import torch

# prepare dataset
# x and y are 3x1 matrices: 3 samples in total, each with a single feature
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# design model using class

class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # (1, 1) are the feature dimensions of the input x and the output y; both are 1-dimensional in this dataset
        # the layer's learnable parameters are w and b, accessible as linear.weight and linear.bias (see the sketch after this listing)
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred


model = LinearModel()

# construct loss and optimizer
# criterion = torch.nn.MSELoss(size_average=False)  # size_average is deprecated; use reduction instead
criterion = torch.nn.MSELoss(reduction='sum')  # sum the squared errors rather than averaging them
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # model.parameters() hands the learnable parameters (w and b) to the optimizer

# training cycle forward, backward, update
for epoch in range(100):
    y_pred = model(x_data)  # forward: predict (model(x) calls forward() via __call__)
    loss = criterion(y_pred, y_data)  # forward: compute the loss
    print(epoch, loss.item())

    optimizer.zero_grad()  # gradients computed by .backward() accumulate, so clear them before each backward pass (see the sketch after this listing)
    loss.backward()  # backward: autograd computes the gradients automatically
    optimizer.step()  # update the parameters, i.e. the values of w and b

print('w = ', model.linear.weight.item())
print('b = ', model.linear.bias.item())

x_test = torch.Tensor([[4.0]])
y_test = model(x_test)
print('y_pred = ', y_test.data)
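
The comment in LinearModel above notes that the layer's learnable parameters are w and b. As a minimal standalone sketch (illustrative, not part of the lecture code), you can verify that torch.nn.Linear applies exactly the affine map y = x·wᵀ + b built from those two parameters:

import torch

linear = torch.nn.Linear(1, 1)  # fresh layer with randomly initialized w and b
x = torch.Tensor([[4.0]])
manual = x @ linear.weight.T + linear.bias  # the same affine map the layer applies
print(torch.allclose(manual, linear(x)))    # True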

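As noted in the training loop, gradients produced by .backward() accumulate in .grad until they are cleared. A tiny standalone sketch (again illustrative, not from the lecture) makes the accumulation visible:

import torch

w = torch.tensor([1.0], requires_grad=True)
(2 * w).backward()   # d(2w)/dw = 2
(2 * w).backward()   # without zeroing, another 2 is added on top
print(w.grad)        # tensor([4.]) -- the two gradients accumulated
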
Results

0 37.81785202026367
1 17.019359588623047
2 7.757819175720215
3 3.6322386264801025
4 1.793077826499939
5 0.9718062281608582
6 0.6037052273750305
7 0.43737754225730896
8 0.360909640789032
9 0.32448017597198486
10 0.30590832233428955
11 0.2953204810619354
12 0.28831955790519714
13 0.28294938802719116
14 0.27833670377731323
15 0.2740930914878845
16 0.2700458765029907
17 0.266116738319397
18 0.2622707486152649
19 0.2584918439388275
20 0.25477272272109985
21 0.2511092722415924
22 0.24749961495399475
23 0.2439422309398651
24 0.2404361367225647
25 0.2369808405637741
26 0.2335747927427292
27 0.2302180528640747
28 0.2269096076488495
29 0.22364839911460876
30 0.2204342484474182
31 0.21726639568805695
32 0.21414369344711304
33 0.21106618642807007
34 0.20803290605545044
35 0.20504306256771088
36 0.20209631323814392
37 0.19919180870056152
38 0.1963292956352234
39 0.19350768625736237
40 0.19072668254375458
41 0.18798553943634033
42 0.18528401851654053
43 0.1826210618019104
44 0.1799963414669037
45 0.1774098426103592
46 0.17485976219177246
47 0.17234686017036438
48 0.16986995935440063
49 0.1674286425113678
50 0.16502246260643005
51 0.16265088319778442
52 0.16031330823898315
53 0.15800945460796356
54 0.15573850274085999
55 0.15350015461444855
56 0.15129432082176208
57 0.1491198092699051
58 0.14697690308094025
59 0.1448644995689392
60 0.14278265833854675
61 0.14073050022125244
62 0.13870814442634583
63 0.13671457767486572
64 0.13474977016448975
65 0.13281317055225372
66 0.1309044063091278
67 0.12902316451072693
68 0.12716907262802124
69 0.12534113228321075
70 0.12354008108377457
71 0.12176437675952911
72 0.12001438438892365
73 0.11828980594873428
74 0.11658982932567596
75 0.11491402983665466
76 0.11326255649328232
77 0.11163493990898132
78 0.11003047227859497
79 0.10844922065734863
80 0.10689061135053635
81 0.10535458475351334
82 0.10384047031402588
83 0.10234814882278442
84 0.10087712109088898
85 0.09942737966775894
86 0.09799840301275253
87 0.09658998250961304
88 0.09520183503627777
89 0.09383369982242584
90 0.09248516708612442
91 0.09115590155124664
92 0.08984582126140594
93 0.0885547548532486
94 0.08728198707103729
95 0.08602754771709442
96 0.08479107916355133
97 0.08357279747724533
98 0.08237159252166748
99 0.0811876729130745
w =  1.8103132247924805
b =  0.43120279908180237
y_pred =  tensor([[7.6725]])
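
After 100 epochs the learned parameters (w ≈ 1.81, b ≈ 0.43) are still some way from the exact fit for this dataset, w = 2 and b = 0, which is why the prediction for x = 4 comes out around 7.67 rather than 8. As a minimal sketch (a hypothetical tweak, not in the original code), simply raising the epoch count lets SGD keep driving the loss toward zero:

for epoch in range(1000):  # same loop as above, just more iterations
    y_pred = model(x_data)
    loss = criterion(y_pred, y_data)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# w and b now end up much closer to 2.0 and 0.0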

For more detail you can also refer to the blogs of 错错莫 and 安好家的小朋友; this article does not go into the specifics further, though more may be added later.

Reposted from blog.csdn.net/weixin_46087812/article/details/114133075