5. Implementing a Two-Layer Neural Network in Several Ways

A fully-connected ReLU network with a single hidden layer and no bias, trained to predict y from x using an L2 loss.

The forward pass and loss are:

h = x · w1
h_relu = max(h, 0)
y_pred = h_relu · w2
loss = Σ (y_pred - y)²

This implementation uses numpy alone to compute the forward pass, the loss, and the backward pass.

A numpy ndarray is a generic n-dimensional array. It knows nothing about deep learning, gradients, or computation graphs; it is simply a data structure for numerical computation.
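
For reference, the gradients computed in the backward pass of the code below follow directly from the chain rule (written with the same variable names as the code):

grad_y_pred = d(loss)/d(y_pred) = 2 · (y_pred - y)
grad_w2 = h_relu.T · grad_y_pred          (since y_pred = h_relu · w2)
grad_h_relu = grad_y_pred · w2.T
grad_h = grad_h_relu, zeroed where h < 0  (ReLU blocks the gradient there)
grad_w1 = x.T · grad_h                    (since h = x · w1)

The complete numpy implementation: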

import numpy as np

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
# Here: 64 training examples, 1000-dimensional input, 100-dimensional hidden layer, 10-dimensional output
N, D_in, H, D_out = 64, 1000, 100, 10

# Create some random training data: inputs x and targets y
x = np.random.randn(N, D_in)
y = np.random.randn(N, D_out)

# Randomly initialize weights
w1 = np.random.randn(D_in, H)
w2 = np.random.randn(H, D_out)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.dot(w1)
    h_relu = np.maximum(h, 0)
    y_pred = h_relu.dot(w2)

    # Compute and print loss
    loss = np.square(y_pred - y).sum()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    # loss = sum((y_pred - y) ** 2), so d(loss)/d(y_pred) = 2 * (y_pred - y)
    grad_y_pred = 2.0 * (y_pred - y)
    # Chain rule through y_pred = h_relu.dot(w2)
    grad_w2 = h_relu.T.dot(grad_y_pred)
    grad_h_relu = grad_y_pred.dot(w2.T)
    # ReLU passes the gradient only where its input h was positive
    grad_h = grad_h_relu.copy()
    grad_h[h < 0] = 0
    grad_w1 = x.T.dot(grad_h)

    # Update weights
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

Output:

0 34399246.46047344
1 29023199.257758312
2 25155679.85447208
3 20344203.603057466
4 14771404.625789404
5 9796072.99431371
6 6194144.749997159
7 3948427.3657580013
8 2637928.1726997104
9 1879876.2597949505
10 1424349.925182723
11 1131684.579785501
12 930879.9521737935
13 783503.167740541
14 669981.8287784329
15 579151.6288421676
16 504610.5781504087
17 442295.18952143926
18 389647.44224490353
19 344718.3535892912
20 306120.2245707266
21 272728.24885829526
22 243778.8617292929
23 218485.92082002352
24 196304.70602822883
25 176774.2980280186
26 159509.34934842546
27 144200.52956072442
28 130597.06878493169
29 118484.47548850597
30 107661.24303895692
31 97973.75762285746
32 89291.0096051952
33 81500.46898789635
34 74477.4654945682
35 68139.90452489533
36 62418.87519034026
37 57241.53801123622
38 52545.34658231941
39 48280.5552386464
40 44399.73653914068
41 40864.495617471934
42 37640.08489317873
43 34695.77852549495
44 32004.894008637555
45 29545.09481447049
46 27292.93700341219
47 25232.87780747312
48 23342.570881009553
49 21606.76105421809
50 20015.62357395961
51 18551.83281521863
52 17204.31407669751
53 15962.736948706759
54 14818.242254751764
55 13762.251705340486
56 12787.060032590252
57 11885.95797873141
58 11053.123737613136
59 10282.617503272711
60 9569.805676161515
61 8909.467534754986
62 8297.782408129178
63 7731.121277369748
64 7205.863671952578
65 6718.146999962471
66 6265.473531640673
67 5845.100232214373
68 5454.557838660972
69 5091.658572234415
70 4754.393958028546
71 4440.682575260731
72 4148.70793229529
73 3877.022931816484
74 3624.088506535617
75 3388.5746286682042
76 3169.088547995476
77 2964.637382505168
78 2774.073275503305
79 2596.433385302534
80 2430.76267026859
81 2276.0929913609607
82 2131.752323451521
83 1997.0334011418258
84 1871.251515936368
85 1753.7448614349362
86 1643.919519574932
87 1541.289735192464
88 1445.3733798948592
89 1355.6688030350501
90 1271.7809967407718
91 1193.2972539295215
92 1119.8689894828083
93 1051.1890596219616
94 986.9044505648076
95 926.7286776893059
96 870.3673474483486
97 817.5707566117906
98 768.1200077715573
99 721.7693127164074
100 678.327084576388
101 637.5984844921132
102 599.3471700265131
103 563.480144773489
104 529.8443636950776
105 498.2900261297218
106 468.69076555164696
107 440.9141759159077
108 414.8406348102356
109 390.36201589159975
110 367.377986904459
111 345.794153113363
112 325.53293001498525
113 306.50254567681907
114 288.6084052220689
115 271.79635277136003
116 255.99599171996437
117 241.14900382305748
118 227.1967308563582
119 214.07331707113855
120 201.72960304299005
121 190.1215354419851
122 179.20332461067048
123 168.93031453109492
124 159.26460190296683
125 150.18028740393805
126 141.6310351604903
127 133.5845565128207
128 126.00708973959819
129 118.87233941614235
130 112.15340039819878
131 105.82589029792051
132 99.86499912782936
133 94.24914703127945
134 88.95813583258435
135 83.97204514587689
136 79.27316202972057
137 74.84491080200985
138 70.66991561621043
139 66.7338564785546
140 63.02257660379944
141 59.52313997962988
142 56.22322968024125
143 53.111763701729714
144 50.175933060002905
145 47.40606340622237
146 44.793256923660664
147 42.328047476976025
148 40.00144856286997
149 37.80536851048289
150 35.73343107253782
151 33.77738183530186
152 31.930774425664392
153 30.187309589532056
154 28.541632056826323
155 26.987624733348596
156 25.52056466134328
157 24.135140772349633
158 22.82592867935182
159 21.589141054848028
160 20.42054265142189
161 19.31668484832083
162 18.27373881475532
163 17.288158658562804
164 16.35722592792538
165 15.47708242556972
166 14.64564401606641
167 13.859533873388633
168 13.116375098259695
169 12.413894134407165
170 11.749589132060152
171 11.121654032769754
172 10.527755403701851
173 9.966244729776452
174 9.43532189744742
175 8.933071972010405
176 8.458011087399335
177 8.008768655278885
178 7.583910598664408
179 7.181953229124071
180 6.801495862472622
181 6.441562835264972
182 6.100994571216053
183 5.778713643608878
184 5.473781077364653
185 5.185201108079436
186 4.910778594973227
187 4.651157323562962
188 4.405448410225631
189 4.172946150445695
190 3.953151092232221
191 3.744908981306408
192 3.5478244092334004
193 3.3613194717897596
194 3.1847221013676847
195 3.017529877609752
196 2.859281816477389
197 2.7094552588397214
198 2.567567563134434
199 2.4332311668647577
200 2.3060182574075156
201 2.1855656999370474
202 2.0714842934675706
203 1.963429527723876
204 1.8610826001126994
205 1.7641435814092026
206 1.6723644556551025
207 1.5854221408589446
208 1.5030539196422792
209 1.425000717598524
210 1.351040918987503
211 1.2809841760961524
212 1.2146023546998657
213 1.151699249786657
214 1.09209327048869
215 1.035622604251914
216 0.9821312023078086
217 0.931398400382507
218 0.8833282657308957
219 0.8377648680868359
220 0.7946102894611751
221 0.7536759265506683
222 0.7148752979674705
223 0.6781061363107517
224 0.6432513423030822
225 0.6102056251947237
226 0.5788694783475372
227 0.5491564979966296
228 0.5209877387920085
229 0.49428305592302185
230 0.4689618516832248
231 0.4449467772589872
232 0.4221771538530943
233 0.4005837129745524
234 0.3801063836563353
235 0.3606948589685684
236 0.34227967093321354
237 0.3248108971814142
238 0.3082453595318127
239 0.29253330742340067
240 0.27763364296473714
241 0.26349628059737484
242 0.25008661105388497
243 0.23736707183615013
244 0.2253000957467423
245 0.2138523361624442
246 0.20298921216044458
247 0.19268673137415065
248 0.1829142719891026
249 0.17364132891347556
250 0.1648403028301571
251 0.1564884304843211
252 0.14856550672688007
253 0.14104626799439057
254 0.13391149996667578
255 0.12714129899063353
256 0.12071506432087209
257 0.11461697949344261
258 0.10882972424002835
259 0.10333904654553265
260 0.09812616827155932
261 0.0931780565261463
262 0.08848217530870164
263 0.08402546175986916
264 0.07979485919700613
265 0.07577840283923465
266 0.07196528340443757
267 0.06834633090302783
268 0.06491078821045225
269 0.06164958715982291
270 0.05855323394724137
271 0.055613749333235304
272 0.05282348135771675
273 0.050174186967151216
274 0.04765819132181068
275 0.045271438294485024
276 0.04300319389006948
277 0.04085019680263658
278 0.03880575227345778
279 0.03686467008847949
280 0.035020834986065126
281 0.033270091765219535
282 0.03160778576828356
283 0.03002859285883898
284 0.028528860354370092
285 0.02710476057141562
286 0.025752260418639834
287 0.024468202569374195
288 0.023248340960852463
289 0.022089463426076428
290 0.020988894619024936
291 0.0199437951726752
292 0.018950933816160725
293 0.01800781808698391
294 0.01711204608500793
295 0.016261162547766717
296 0.015452712469537512
297 0.014684969147452612
298 0.013955631973752351
299 0.013262568423046399
300 0.012604847094777376
301 0.011979553027742364
302 0.011385439315700433
303 0.010820996391120769
304 0.010284594931599429
305 0.009774981235579705
306 0.009290757906358697
307 0.008830808784232147
308 0.008393777417063167
309 0.007978536761423337
310 0.007583920779683997
311 0.007208915598860074
312 0.006852552626458435
313 0.006513949532633323
314 0.006192232365425878
315 0.005886426734936253
316 0.0055958430902459
317 0.005319724577409574
318 0.00505742623277029
319 0.004808048657916753
320 0.004571049645157161
321 0.00434578735628591
322 0.00413168646938554
323 0.003928177611013847
324 0.0037348015880361344
325 0.0035511516515125325
326 0.00337649547244047
327 0.003210468217165082
328 0.0030526001815046697
329 0.0029025877721886788
330 0.002759953496700082
331 0.0026243799521717052
332 0.0024955282428992384
333 0.002373029911103254
334 0.0022565698554979346
335 0.0021458674153799783
336 0.0020406457772717784
337 0.0019405998694530123
338 0.0018454766429137211
339 0.001755053210723282
340 0.0016691117022354827
341 0.0015873764219450975
342 0.001509661141624596
343 0.0014357855252587638
344 0.0013655404837908268
345 0.0012987690170491828
346 0.0012352672659498882
347 0.0011748904401224442
348 0.0011174828717238274
349 0.001062929250585763
350 0.0010110289742089754
351 0.0009616701215304243
352 0.0009147416941442063
353 0.0008701229479127731
354 0.0008276969799764454
355 0.0007873380259436494
356 0.0007489547972389264
357 0.000712451804667229
358 0.000677739371845171
359 0.000644728435464334
360 0.0006133353967986715
361 0.000583485133709109
362 0.0005550950412157807
363 0.0005280932731431832
364 0.0005024110280030845
365 0.00047798591228805434
366 0.0004547507774551706
367 0.0004326546423136559
368 0.0004116382458083261
369 0.0003916440959886334
370 0.0003726296356534275
371 0.0003545443586216977
372 0.000337347352488608
373 0.00032099061370803334
374 0.0003054229784132819
375 0.00029061647064382485
376 0.0002765299098361774
377 0.0002631327221101076
378 0.0002503865963973947
379 0.0002382599294869431
380 0.00022672670184804494
381 0.00021575299560298047
382 0.00020531375263207438
383 0.000195381616896771
384 0.00018593500698085453
385 0.00017694494225329907
386 0.00016839225855899982
387 0.00016025517275686525
388 0.00015251350815142156
389 0.0001451491411549753
390 0.0001381428245892601
391 0.00013147417414693054
392 0.00012512977608770297
393 0.00011909308605343111
394 0.00011334857979979945
395 0.00010788480695473414
396 0.00010268704883570024
397 9.773868892276339e-05
398 9.303020197524704e-05
399 8.85491663624475e-05
400 8.428485316645869e-05
401 8.022778747190388e-05
402 7.636668153099922e-05
403 7.269236014951034e-05
404 6.919607836124983e-05
405 6.586822433827189e-05
406 6.270255584885866e-05
407 5.968876673919747e-05
408 5.682053107105221e-05
409 5.409095915616159e-05
410 5.1493171389161614e-05
411 4.902052422909128e-05
412 4.666751255362057e-05
413 4.442778311925004e-05
414 4.229665906499858e-05
415 4.026848185397114e-05
416 3.833731276858471e-05
417 3.6499489902161296e-05
418 3.47508278392673e-05
419 3.308721418833935e-05
420 3.150227517115802e-05
421 2.9993581650856873e-05
422 2.855746554164107e-05
423 2.719082047553995e-05
424 2.5889677603252346e-05
425 2.4650964056310747e-05
426 2.3472062284026367e-05
427 2.234982200401504e-05
428 2.1281362386681205e-05
429 2.026427108363315e-05
430 1.9295935390705645e-05
431 1.837421991578651e-05
432 1.7497081369101593e-05
433 1.6661834251218388e-05
434 1.586646319751679e-05
435 1.5109303894458321e-05
436 1.4388405284291706e-05
437 1.370208486427819e-05
438 1.3048788200557073e-05
439 1.2426643170448882e-05
440 1.1834592393040145e-05
441 1.1270674124878297e-05
442 1.0733978520453249e-05
443 1.0222848634947958e-05
444 9.736143369136732e-06
445 9.27267469036151e-06
446 8.831344021766777e-06
447 8.411269141029111e-06
448 8.01120400023348e-06
449 7.630301830674667e-06
450 7.267514155906956e-06
451 6.922046453620718e-06
452 6.593111073936982e-06
453 6.279880429162564e-06
454 5.981637837795765e-06
455 5.697647106767495e-06
456 5.427145298606709e-06
457 5.169599426636958e-06
458 4.924305367745114e-06
459 4.690669682311872e-06
460 4.468174841019597e-06
461 4.256343870589204e-06
462 4.054581997203786e-06
463 3.862428795746305e-06
464 3.67941417005127e-06
465 3.50524041752366e-06
466 3.339264640996386e-06
467 3.1811641594010322e-06
468 3.0305596775569415e-06
469 2.8871296080016895e-06
470 2.7505179893217363e-06
471 2.6204108493977605e-06
472 2.496484420691004e-06
473 2.3784400359079537e-06
474 2.266014000554243e-06
475 2.1589189914752615e-06
476 2.0569210781071096e-06
477 1.959750825498873e-06
478 1.867203798935969e-06
479 1.7790548691072608e-06
480 1.6950719025924979e-06
481 1.6150665241638997e-06
482 1.5388752061694276e-06
483 1.4662834016989446e-06
484 1.3971344832895556e-06
485 1.3312570753140638e-06
486 1.2684946752376657e-06
487 1.2087236535163552e-06
488 1.1517952524068044e-06
489 1.0975341827852709e-06
490 1.0458478627989778e-06
491 9.966057254836819e-07
492 9.49690406945525e-07
493 9.04995815168244e-07
494 8.624191220382796e-07
495 8.218529042471487e-07
496 7.831982460890191e-07
497 7.463699242750524e-07
498 7.112838972272693e-07
499 6.778580009634641e-07

PyTorch: Tensors

This time we use PyTorch Tensors to implement the forward pass, compute the loss, and run the backward pass.

A PyTorch Tensor is conceptually very similar to a numpy ndarray. The biggest difference is that a PyTorch Tensor can run on either the CPU or the GPU; to run on the GPU, you simply place the Tensor on a CUDA device.
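
For example, a common pattern (a small sketch, not part of the original example) is to pick the device once and then create or move Tensors onto it:

import torch

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

a = torch.randn(3, 4, device=device)   # created directly on the chosen device
b = torch.randn(3, 4).to(device)       # or moved there afterwards
c = a + b                              # the computation runs on that device

The full training loop written with raw Tensor operations follows: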

import torch


dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random input and output data
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Randomly initialize weights
w1 = torch.randn(D_in, H, device=device, dtype=dtype)
w2 = torch.randn(H, D_out, device=device, dtype=dtype)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y
    h = x.mm(w1)
    h_relu = h.clamp(min=0)
    y_pred = h_relu.mm(w2)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum().item()
    print(t, loss)

    # Backprop to compute gradients of w1 and w2 with respect to loss
    grad_y_pred = 2.0 * (y_pred - y)
    grad_w2 = h_relu.t().mm(grad_y_pred)
    grad_h_relu = grad_y_pred.mm(w2.t())
    grad_h = grad_h_relu.clone()
    grad_h[h < 0] = 0
    grad_w1 = x.t().mm(grad_h)

    # Update weights using gradient descent
    w1 -= learning_rate * grad_w1
    w2 -= learning_rate * grad_w2

Output:

0 31704728.0
1 25331164.0
2 22378086.0
3 19262238.0
4 15348289.0
5 11017595.0
6 7356282.0
7 4705923.5
8 3027346.5
9 2012536.375
10 1409662.25
11 1041771.75
12 807321.0625
13 649262.0
14 536533.1875
15 451980.875
16 385983.53125
17 332925.53125
18 289368.1875
19 253030.78125
20 222354.703125
21 196214.3125
22 173766.515625
23 154378.140625
24 137539.375
25 122867.1015625
26 110037.3515625
27 98769.4921875
28 88842.109375
29 80063.15625
30 72279.015625
31 65361.66796875
32 59195.42578125
33 53687.4453125
34 48757.57421875
35 44338.4453125
36 40370.34765625
37 36803.1484375
38 33587.4453125
39 30684.1640625
40 28059.435546875
41 25683.255859375
42 23528.814453125
43 21570.8515625
44 19792.4296875
45 18175.244140625
46 16704.6640625
47 15364.2578125
48 14141.7509765625
49 13026.609375
50 12007.3115234375
51 11075.3896484375
52 10221.8857421875
53 9439.876953125
54 8722.13671875
55 8063.46826171875
56 7458.20703125
57 6901.8876953125
58 6390.34375
59 5919.4794921875
60 5485.79345703125
61 5086.119140625
62 4718.2138671875
63 4378.970703125
64 4065.92578125
65 3776.7900390625
66 3509.54296875
67 3262.43408203125
68 3033.942626953125
69 2822.52490234375
70 2627.182373046875
71 2446.365966796875
72 2278.8046875
73 2123.408447265625
74 1979.00146484375
75 1845.013427734375
76 1720.6822509765625
77 1605.2548828125
78 1498.001953125
79 1398.356201171875
80 1305.7220458984375
81 1219.5579833984375
82 1139.3939208984375
83 1064.7841796875
84 995.3250732421875
85 930.6298217773438
86 870.3472900390625
87 814.1729125976562
88 761.8153686523438
89 713.0128784179688
90 667.50048828125
91 625.0264892578125
92 585.3772583007812
93 548.3762817382812
94 513.8129272460938
95 481.5259094238281
96 451.376708984375
97 423.1982116699219
98 396.865234375
99 372.23583984375
100 349.208984375
101 327.65960693359375
102 307.49652099609375
103 288.6243591308594
104 270.9569396972656
105 254.41790771484375
106 238.9322052001953
107 224.42202758789062
108 210.82664489746094
109 198.08383178710938
110 186.14157104492188
111 174.94784545898438
112 164.45217895507812
113 154.6090850830078
114 145.38900756835938
115 136.7398681640625
116 128.62008666992188
117 121.001708984375
118 113.84794616699219
119 107.13176727294922
120 100.82424926757812
121 94.90043640136719
122 89.33421325683594
123 84.10637664794922
124 79.19412994384766
125 74.57848358154297
126 70.23960876464844
127 66.15946197509766
128 62.32460403442383
129 58.7183723449707
130 55.32723617553711
131 52.13628387451172
132 49.13447570800781
133 46.310585021972656
134 43.65383529663086
135 41.152828216552734
136 38.799072265625
137 36.583656311035156
138 34.49782943725586
139 32.53558349609375
140 30.6860294342041
141 28.94465446472168
142 27.304447174072266
143 25.759523391723633
144 24.304840087890625
145 22.93392562866211
146 21.641254425048828
147 20.42369842529297
148 19.276079177856445
149 18.194564819335938
150 17.175493240356445
151 16.214174270629883
152 15.308029174804688
153 14.454139709472656
154 13.648143768310547
155 12.88845157623291
156 12.171833038330078
157 11.49567699432373
158 10.85841178894043
159 10.256678581237793
160 9.689424514770508
161 9.154097557067871
162 8.64884090423584
163 8.172189712524414
164 7.721974849700928
165 7.297136306762695
166 6.8962836265563965
167 6.5177459716796875
168 6.160311698913574
169 5.822811126708984
170 5.5043110847473145
171 5.203525066375732
172 4.919389724731445
173 4.651163101196289
174 4.3978190422058105
175 4.158350944519043
176 3.9322471618652344
177 3.718606948852539
178 3.516770839691162
179 3.3262054920196533
180 3.1460940837860107
181 2.975762367248535
182 2.814879894256592
183 2.662900447845459
184 2.5192079544067383
185 2.3834681510925293
186 2.255030393600464
187 2.13377046585083
188 2.0190846920013428
189 1.9105591773986816
190 1.807981014251709
191 1.7110538482666016
192 1.6193859577178955
193 1.5326906442642212
194 1.4506947994232178
195 1.373248815536499
196 1.2998838424682617
197 1.2304624319076538
198 1.1650127172470093
199 1.1028441190719604
200 1.0442299842834473
201 0.9886825084686279
202 0.9362077713012695
203 0.8864397406578064
204 0.8394078016281128
205 0.7948980927467346
206 0.7528337836265564
207 0.7129263281822205
208 0.6751680374145508
209 0.6395058035850525
210 0.6058014035224915
211 0.573722243309021
212 0.5434805750846863
213 0.5148582458496094
214 0.48777079582214355
215 0.462094783782959
216 0.4378334879875183
217 0.41474175453186035
218 0.3928961455821991
219 0.37232136726379395
220 0.35279765725135803
221 0.3343387842178345
222 0.31676602363586426
223 0.3001691997051239
224 0.2844657897949219
225 0.26963645219802856
226 0.2555326223373413
227 0.24219271540641785
228 0.22961300611495972
229 0.21758520603179932
230 0.20622654259204865
231 0.19550156593322754
232 0.18533945083618164
233 0.17566744983196259
234 0.16653285920619965
235 0.15787597000598907
236 0.14970409870147705
237 0.14190873503684998
238 0.13456779718399048
239 0.12759016454219818
240 0.12096268683671951
241 0.11470359563827515
242 0.1087842658162117
243 0.10314527899026871
244 0.09780357778072357
245 0.09277193248271942
246 0.08799058943986893
247 0.0834306925535202
248 0.07912513613700867
249 0.0750374048948288
250 0.07118058204650879
251 0.06751800328493118
252 0.06403960287570953
253 0.06074457988142967
254 0.05762597173452377
255 0.05466882884502411
256 0.0518682561814785
257 0.04920265078544617
258 0.04668186977505684
259 0.044272683560848236
260 0.042025547474622726
261 0.03986666351556778
262 0.037817493081092834
263 0.03588436171412468
264 0.03405837342143059
265 0.03232688084244728
266 0.030674563720822334
267 0.029108863323926926
268 0.027641309425234795
269 0.026221055537462234
270 0.024893635883927345
271 0.02361663617193699
272 0.022424183785915375
273 0.02127956785261631
274 0.020195599645376205
275 0.019174542278051376
276 0.01821214333176613
277 0.01728914864361286
278 0.016413141041994095
279 0.01559178251773119
280 0.014809946529567242
281 0.014066735282540321
282 0.01335320807993412
283 0.012697878293693066
284 0.012057363986968994
285 0.011450453661382198
286 0.010880804620683193
287 0.01034202054142952
288 0.009831000119447708
289 0.009346149861812592
290 0.008878068067133427
291 0.00844407919794321
292 0.008024066686630249
293 0.007635605521500111
294 0.0072587537579238415
295 0.0069105857983231544
296 0.006573254242539406
297 0.006256978493183851
298 0.005949943792074919
299 0.005672339349985123
300 0.005388857331126928
301 0.0051320018246769905
302 0.004887753631919622
303 0.004658843856304884
304 0.0044357734732329845
305 0.004228176549077034
306 0.004027842078357935
307 0.003840153571218252
308 0.0036594069097191095
309 0.003490033093839884
310 0.0033292584121227264
311 0.0031785538885742426
312 0.003031808650121093
313 0.002896031131967902
314 0.0027637507300823927
315 0.002640662482008338
316 0.0025206280406564474
317 0.0024077417328953743
318 0.0022998738568276167
319 0.0022006493527442217
320 0.002101228339597583
321 0.0020116977393627167
322 0.0019245940493419766
323 0.0018393839709460735
324 0.0017626716289669275
325 0.001689193886704743
326 0.0016162614338099957
327 0.0015509161166846752
328 0.0014848458813503385
329 0.0014197597047314048
330 0.0013633108465000987
331 0.0013077231124043465
332 0.001255737035535276
333 0.001203768653795123
334 0.0011556316167116165
335 0.001109623583033681
336 0.0010652983328327537
337 0.0010259364498779178
338 0.0009847141336649656
339 0.0009464982431381941
340 0.0009106658981181681
341 0.0008753696456551552
342 0.0008441813406534493
343 0.0008131045615300536
344 0.0007834337884560227
345 0.0007538509089499712
346 0.0007265112362802029
347 0.0007019559852778912
348 0.0006759824464097619
349 0.0006510045495815575
350 0.0006284148548729718
351 0.0006068204529583454
352 0.0005856421194039285
353 0.0005672844126820564
354 0.0005473798955790699
355 0.0005280547775328159
356 0.0005113428342156112
357 0.0004943141248077154
358 0.00047874540905468166
359 0.00046168401604518294
360 0.000447523663751781
361 0.0004326198832131922
362 0.00041988512384705245
363 0.00040688799344934523
364 0.0003942836483474821
365 0.000381028454285115
366 0.0003701212117448449
367 0.00035913955071009696
368 0.0003480427258182317
369 0.00033798906952142715
370 0.00032894761534407735
371 0.0003196335455868393
372 0.0003099186287727207
373 0.0003019550640601665
374 0.00029274728149175644
375 0.0002844816190190613
376 0.00027625024085864425
377 0.0002687727683223784
378 0.0002608516369946301
379 0.00025311342324130237
380 0.0002469048195052892
381 0.00024049097555689514
382 0.0002342124644201249
383 0.00022811403323430568
384 0.00022231723414734006
385 0.0002166029589716345
386 0.00021077181736472994
387 0.00020510501053649932
388 0.00020020001102238894
389 0.0001948442222783342
390 0.00018990584067068994
391 0.00018529882072471082
392 0.00018070911755785346
393 0.00017650797963142395
394 0.00017214834224432707
395 0.0001683011942077428
396 0.00016451899136882275
397 0.00016050187696237117
398 0.00015686434926465154
399 0.00015321985119953752
400 0.0001501761726103723
401 0.00014639270375482738
402 0.00014274154091253877
403 0.0001396275474689901
404 0.0001364489580737427
405 0.00013346801279112697
406 0.00013024920190218836
407 0.00012755846546497196
408 0.00012532222899608314
409 0.0001224723382620141
410 0.00011974618973908946
411 0.00011740042100427672
412 0.00011441943206591532
413 0.00011229746451135725
414 0.00010995937191182747
415 0.00010784588812384754
416 0.00010610915342113003
417 0.0001038835325744003
418 0.00010166718857362866
419 9.979418973671272e-05
420 9.793229401111603e-05
421 9.590695117367432e-05
422 9.412408689968288e-05
423 9.244915418094024e-05
424 9.07004505279474e-05
425 8.880807581590489e-05
426 8.733373397262767e-05
427 8.574980893172324e-05
428 8.392545714741573e-05
429 8.241042814916e-05
430 8.080529369181022e-05
431 7.939618808450177e-05
432 7.762608584016562e-05
433 7.651503256056458e-05
434 7.531026494689286e-05
435 7.377369183814153e-05
436 7.305829058168456e-05
437 7.153345359256491e-05
438 7.028930122032762e-05
439 6.921228487044573e-05
440 6.803637370467186e-05
441 6.695889896946028e-05
442 6.582081550732255e-05
443 6.459224823629484e-05
444 6.373634096235037e-05
445 6.257549830479547e-05
446 6.140403274912387e-05
447 6.0855145420646295e-05
448 5.961475471849553e-05
449 5.8864923630608246e-05
450 5.767813490820117e-05
451 5.6913944717962295e-05
452 5.6101172958733514e-05
453 5.5124517530202866e-05
454 5.3974403272150084e-05
455 5.3289859351934865e-05
456 5.267193409963511e-05
457 5.193992910790257e-05
458 5.1057362725259736e-05
459 5.001332101528533e-05
460 4.934622847940773e-05
461 4.873232319368981e-05
462 4.794801498064771e-05
463 4.7256213292712346e-05
464 4.667185203288682e-05
465 4.5966684410814196e-05
466 4.526913107838482e-05
467 4.486504985834472e-05
468 4.413699934957549e-05
469 4.358588557806797e-05
470 4.305447146180086e-05
471 4.2635525460354984e-05
472 4.186580190435052e-05
473 4.1199065890396014e-05
474 4.055891258758493e-05
475 4.017409810330719e-05
476 3.963488052249886e-05
477 3.913479667971842e-05
478 3.8683563616359606e-05
479 3.806965833064169e-05
480 3.7681969843106344e-05
481 3.737308725249022e-05
482 3.669063517008908e-05
483 3.630801438703202e-05
484 3.584063597372733e-05
485 3.539228782756254e-05
486 3.4903492633020505e-05
487 3.4551478165667504e-05
488 3.4130087442463264e-05
489 3.377102984813973e-05
490 3.320474206702784e-05
491 3.283058322267607e-05
492 3.254006151109934e-05
493 3.1907922675600275e-05
494 3.1705902074463665e-05
495 3.147914321743883e-05
496 3.111985279247165e-05
497 3.079450470977463e-05
498 3.0240607884479687e-05
499 2.9828101105522364e-05
  • A simple autograd example

import torch

# Create tensors.
x = torch.tensor(1., requires_grad=True)
w = torch.tensor(2., requires_grad=True)
b = torch.tensor(3., requires_grad=True)

# Build a computational graph.
y = w * x + b    # y = 2 * x + 3

# Compute gradients.
y.backward()

# Print out the gradients.
print(x.grad)    # x.grad = 2 
print(w.grad)    # w.grad = 1 
print(b.grad)    # b.grad = 1 

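Since y = w * x + b with x = 1, w = 2, b = 3, the gradients are dy/dx = w = 2, dy/dw = x = 1, and dy/db = 1, matching the printed values.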

PyTorch: Tensors and autograd

A key feature of PyTorch is autograd: once the forward pass has been defined and the loss computed, PyTorch can automatically compute the gradient of the loss with respect to all of the model's parameters.

A PyTorch Tensor represents a node in a computation graph. If x is a Tensor with x.requires_grad=True, then x.grad is another Tensor holding the gradient of some scalar value (usually the loss) with respect to x.
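
One detail worth noting: backward() accumulates gradients into .grad rather than overwriting them, which is why the training loop below zeroes w1.grad and w2.grad after every update. A minimal sketch:

import torch

w = torch.tensor(2.0, requires_grad=True)

(w * 3).backward()
print(w.grad)      # tensor(3.)
(w * 3).backward()
print(w.grad)      # tensor(6.)  <- accumulated, not overwritten
w.grad.zero_()     # reset before the next backward pass

With that in mind, here is the full training loop using autograd: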

import torch

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold the input and output data.
# requires_grad defaults to False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.randn(N, D_in, device=device, dtype=dtype)
y = torch.randn(N, D_out, device=device, dtype=dtype)

# Create random Tensors for the weights.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
w1 = torch.randn(D_in, H, device=device, dtype=dtype, requires_grad=True)
w2 = torch.randn(H, D_out, device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(500):
    # Forward pass: compute predicted y using operations on Tensors. These are
    # the same operations as before, but we do not need to keep references to
    # intermediate values, since we are not computing the backward pass by hand.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Compute the loss using operations on Tensors.
    # loss is a Tensor holding a single number;
    # loss.item() extracts that number as a Python float.
    loss = (y_pred - y).pow(2).sum()
    print(t, loss.item())

    # Use autograd to compute the backward pass. This call computes the gradient
    # of loss with respect to every Tensor with requires_grad=True. Afterwards,
    # w1.grad and w2.grad hold the gradient of the loss with respect to w1 and w2.
    loss.backward()

    # Manually update the weights using gradient descent (later we will see how
    # to do this automatically). We wrap the updates in torch.no_grad() because
    # w1 and w2 have requires_grad=True, but we do not want autograd to track
    # the update itself.
    # An alternative is to operate on weight.data and weight.grad.data:
    # tensor.data gives a Tensor that shares storage with the original tensor
    # but does not record the operation in the computation graph.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        # Manually zero the gradients after updating weights
        w1.grad.zero_()
        w2.grad.zero_()

Output:

0 31590738.0
1 34389704.0
2 44504280.0
3 52598508.0
4 46752264.0
5 27227634.0
6 10779343.0
7 3889138.75
8 1856397.875
9 1232127.25
10 967278.5
11 806383.9375
12 687169.25
13 591936.25
14 513579.40625
15 448339.5
16 393390.71875
17 346772.71875
18 306952.625
19 272743.90625
20 243250.578125
21 217760.4375
22 195513.75
23 176012.4375
24 158848.59375
25 143694.4375
26 130272.53125
27 118357.1328125
28 107732.5625
29 98245.9296875
30 89754.4375
31 82145.9765625
32 75299.703125
33 69130.7265625
34 63549.09375
35 58498.18359375
36 53914.7421875
37 49751.984375
38 45963.8515625
39 42512.19140625
40 39364.1484375
41 36486.7421875
42 33852.94921875
43 31441.951171875
44 29230.11328125
45 27200.080078125
46 25335.595703125
47 23618.97265625
48 22036.193359375
49 20575.412109375
50 19227.5078125
51 17980.865234375
52 16826.919921875
53 15756.392578125
54 14762.513671875
55 13839.58203125
56 12981.9228515625
57 12184.3896484375
58 11442.140625
59 10750.8681640625
60 10106.751953125
61 9505.8720703125
62 8944.9736328125
63 8420.947265625
64 7931.27734375
65 7473.25
66 7044.6455078125
67 6643.35693359375
68 6267.51123046875
69 5915.16064453125
70 5584.55615234375
71 5274.3408203125
72 4983.08349609375
73 4709.5224609375
74 4452.46484375
75 4210.7333984375
76 3983.40234375
77 3769.482177734375
78 3568.060302734375
79 3378.383056640625
80 3199.61962890625
81 3031.24169921875
82 2872.541748046875
83 2722.83056640625
84 2581.67431640625
85 2448.48193359375
86 2322.647705078125
87 2203.833251953125
88 2091.537353515625
89 1985.4141845703125
90 1885.1256103515625
91 1790.2205810546875
92 1700.4173583984375
93 1615.4560546875
94 1535.025390625
95 1458.8707275390625
96 1386.747314453125
97 1318.4639892578125
98 1253.7381591796875
99 1192.38330078125
100 1134.2308349609375
101 1079.12255859375
102 1026.814697265625
103 977.2050170898438
104 930.1414184570312
105 885.4822998046875
106 843.100341796875
107 802.8572387695312
108 764.6590576171875
109 728.3578491210938
110 693.8634033203125
111 661.0936279296875
112 629.9700927734375
113 600.3693237304688
114 572.22705078125
115 545.47802734375
116 520.0562744140625
117 495.85888671875
118 472.8380432128906
119 450.9430236816406
120 430.1151428222656
121 410.2862854003906
122 391.4206237792969
123 373.4564514160156
124 356.3577880859375
125 340.0721130371094
126 324.5631103515625
127 309.78851318359375
128 295.7184143066406
129 282.31243896484375
130 269.5414733886719
131 257.3679504394531
132 245.76673889160156
133 234.71142578125
134 224.170166015625
135 214.1173553466797
136 204.53355407714844
137 195.39988708496094
138 186.68582153320312
139 178.3736114501953
140 170.44247436523438
141 162.87950134277344
142 155.6625518798828
143 148.77378845214844
144 142.2010955810547
145 135.93218994140625
146 129.94813537597656
147 124.23357391357422
148 118.77680969238281
149 113.56990051269531
150 108.5975341796875
151 103.852783203125
152 99.31798553466797
153 94.98828125
154 90.8508071899414
155 86.90143585205078
156 83.12843322753906
157 79.5234603881836
158 76.07766723632812
159 72.78606414794922
160 69.64219665527344
161 66.63894653320312
162 63.768882751464844
163 61.02363204956055
164 58.39965057373047
165 55.89255142211914
166 53.495147705078125
167 51.20370864868164
168 49.01222229003906
169 46.91761016845703
170 44.91386795043945
171 42.998748779296875
172 41.16694641113281
173 39.41572189331055
174 37.74125289916992
175 36.139198303222656
176 34.60701370239258
177 33.140785217285156
178 31.7379093170166
179 30.396512985229492
180 29.113128662109375
181 27.88433837890625
182 26.709943771362305
183 25.584423065185547
184 24.5084285736084
185 23.477943420410156
186 22.491390228271484
187 21.548294067382812
188 20.645343780517578
189 19.78101348876953
190 18.95306396484375
191 18.160476684570312
192 17.402193069458008
193 16.67613410949707
194 15.980989456176758
195 15.31508731842041
196 14.677319526672363
197 14.066808700561523
198 13.482386589050293
199 12.92292594909668
200 12.386455535888672
201 11.873013496398926
202 11.381402969360352
203 10.910120964050293
204 10.459120750427246
205 10.027090072631836
206 9.613011360168457
207 9.216401100158691
208 8.836578369140625
209 8.472214698791504
210 8.123515129089355
211 7.7894086837768555
212 7.469310760498047
213 7.162317752838135
214 6.868289470672607
215 6.586775779724121
216 6.316912651062012
217 6.058456897735596
218 5.810366630554199
219 5.572868824005127
220 5.345065116882324
221 5.1266560554504395
222 4.917425155639648
223 4.716660976409912
224 4.524556636810303
225 4.340378761291504
226 4.163658618927002
227 3.994236469268799
228 3.8317785263061523
229 3.6761226654052734
230 3.5268383026123047
231 3.383906126022339
232 3.2465603351593018
233 3.1150259971618652
234 2.988828420639038
235 2.8679425716400146
236 2.7518584728240967
237 2.640673875808716
238 2.5340168476104736
239 2.431691884994507
240 2.3335652351379395
241 2.2394731044769287
242 2.1492316722869873
243 2.0626134872436523
244 1.979622721672058
245 1.8999762535095215
246 1.823598861694336
247 1.7503103017807007
248 1.6799509525299072
249 1.6125677824020386
250 1.5478930473327637
251 1.4858715534210205
252 1.4262574911117554
253 1.3690857887268066
254 1.3144354820251465
255 1.2618528604507446
256 1.2113248109817505
257 1.1629149913787842
258 1.116491675376892
259 1.0719397068023682
260 1.0291515588760376
261 0.9881429672241211
262 0.948744535446167
263 0.9109649658203125
264 0.8747091293334961
265 0.8399428725242615
266 0.8065575957298279
267 0.7745365500450134
268 0.743747353553772
269 0.7142344117164612
270 0.6858413219451904
271 0.6586267352104187
272 0.6324704885482788
273 0.6074603796005249
274 0.5833799242973328
275 0.5603098273277283
276 0.538133442401886
277 0.5167998671531677
278 0.4963699281215668
279 0.47679442167282104
280 0.4579639434814453
281 0.4399140477180481
282 0.4225271940231323
283 0.4058454930782318
284 0.3898758292198181
285 0.37450703978538513
286 0.3596993088722229
287 0.34558868408203125
288 0.3319928050041199
289 0.3188936114311218
290 0.3063645660877228
291 0.2943674921989441
292 0.282818078994751
293 0.27166399359703064
294 0.26100337505340576
295 0.2508018910884857
296 0.24095848202705383
297 0.23150235414505005
298 0.22243773937225342
299 0.2137538194656372
300 0.20537863671779633
301 0.19736173748970032
302 0.18962986767292023
303 0.18225198984146118
304 0.1751105785369873
305 0.16825076937675476
306 0.16170242428779602
307 0.15539704263210297
308 0.1493482142686844
309 0.14353060722351074
310 0.13794031739234924
311 0.13258764147758484
312 0.1274181455373764
313 0.12247373908758163
314 0.11770471930503845
315 0.1131085455417633
316 0.10872947424650192
317 0.10449523478746414
318 0.10042322427034378
319 0.096539705991745
320 0.09280091524124146
321 0.08919163793325424
322 0.0857444480061531
323 0.08242055028676987
324 0.0792209729552269
325 0.07614120841026306
326 0.07320591062307358
327 0.07038237899541855
328 0.06766688823699951
329 0.06505295634269714
330 0.06254184246063232
331 0.06012542173266411
332 0.05779905244708061
333 0.05556876212358475
334 0.05343194305896759
335 0.05136372521519661
336 0.04940014332532883
337 0.047491997480392456
338 0.04566337913274765
339 0.043917763978242874
340 0.04224282503128052
341 0.04061662033200264
342 0.039049748331308365
343 0.037563201040029526
344 0.036115214228630066
345 0.03474092110991478
346 0.03340528532862663
347 0.0321282260119915
348 0.03089768998324871
349 0.02971574291586876
350 0.028576789423823357
351 0.02748893015086651
352 0.026447011157870293
353 0.025437500327825546
354 0.024477161467075348
355 0.02354368567466736
356 0.02265477180480957
357 0.021796781569719315
358 0.020970573648810387
359 0.02017371729016304
360 0.019415026530623436
361 0.01868613064289093
362 0.017972717061638832
363 0.017293401062488556
364 0.01664922386407852
365 0.01602707803249359
366 0.015429402701556683
367 0.014847123995423317
368 0.01428967621177435
369 0.013754935935139656
370 0.013242769055068493
371 0.012747152708470821
372 0.012274167500436306
373 0.011820298619568348
374 0.01137969084084034
375 0.01095337513834238
376 0.010552221909165382
377 0.010165980085730553
378 0.009792990982532501
379 0.009431964717805386
380 0.009089745581150055
381 0.0087546082213521
382 0.008431270718574524
383 0.008125036023557186
384 0.007833220064640045
385 0.007543003186583519
386 0.007273838389664888
387 0.007010471075773239
388 0.006761929951608181
389 0.006515162996947765
390 0.0062854718416929245
391 0.006053847726434469
392 0.005839650984853506
393 0.0056365784257650375
394 0.005430158693343401
395 0.005243257619440556
396 0.005058295093476772
397 0.0048800683580338955
398 0.004707938991487026
399 0.004541801754385233
400 0.004385354463011026
401 0.0042332010343670845
402 0.0040851193480193615
403 0.003942274488508701
404 0.003809330752119422
405 0.0036788880825042725
406 0.0035530496388673782
407 0.0034328829497098923
408 0.003316469956189394
409 0.0032058244105428457
410 0.003095718566328287
411 0.002996482653543353
412 0.002896404592320323
413 0.002801347989588976
414 0.0027062646113336086
415 0.0026161009445786476
416 0.002530781552195549
417 0.002449025632813573
418 0.002370838774368167
419 0.002294242149218917
420 0.002220114693045616
421 0.002151642693206668
422 0.0020829373970627785
423 0.0020190104842185974
424 0.0019563380628824234
425 0.0018947365460917354
426 0.0018343634437769651
427 0.0017779992194846272
428 0.0017241643508896232
429 0.001670036930590868
430 0.0016198739176616073
431 0.0015696510672569275
432 0.0015243508387356997
433 0.0014775173040106893
434 0.0014329483965411782
435 0.0013928760308772326
436 0.0013534030877053738
437 0.0013135093031451106
438 0.0012740943348035216
439 0.0012363230343908072
440 0.001201886567287147
441 0.001167844282463193
442 0.001134263351559639
443 0.0011019724188372493
444 0.0010723149171099067
445 0.0010419739410281181
446 0.0010126412380486727
447 0.0009848815388977528
448 0.0009563455241732299
449 0.0009308728040196002
450 0.00090641004499048
451 0.0008831481100060046
452 0.0008599660359323025
453 0.0008373564924113452
454 0.0008155326941050589
455 0.000794055056758225
456 0.0007728718337602913
457 0.0007523726671934128
458 0.0007338629802688956
459 0.0007152488688006997
460 0.000696428760420531
461 0.0006795382359996438
462 0.0006630202988162637
463 0.0006455725524574518
464 0.0006298987427726388
465 0.0006149400724098086
466 0.0006001737783662975
467 0.0005853470065630972
468 0.0005718616303056479
469 0.0005581608274951577
470 0.0005458852974697948
471 0.0005324622616171837
472 0.0005193008692003787
473 0.0005074077052995563
474 0.0004961146041750908
475 0.0004842482740059495
476 0.00047333139809779823
477 0.00046403566375374794
478 0.00045250554103404284
479 0.000442950171418488
480 0.0004327989590819925
481 0.00042354772449471056
482 0.000413733534514904
483 0.00040549712139181793
484 0.000397101161070168
485 0.00038828636752441525
486 0.0003802196588367224
487 0.0003723677946254611
488 0.00036461124545894563
489 0.0003568623214960098
490 0.0003494620032142848
491 0.00034172856248915195
492 0.00033552979584783316
493 0.0003281632671132684
494 0.00032163254218176007
495 0.00031514454167336226
496 0.00030899044941179454
497 0.0003025623445864767
498 0.0002974127419292927
499 0.00029148231260478497

PyTorch: nn

This time we use the PyTorch nn package to build the network. We still rely on autograd to build the computation graph and compute the gradients; nn simply provides higher-level building blocks (Modules and loss functions) on top of it, and PyTorch computes the gradients of all the model's parameters for us.
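
As a quick illustration (a small sketch, not part of the original example), each Linear module inside the Sequential container holds its own weight and bias Tensors, and model.parameters() iterates over all of them; these are exactly the Tensors updated in the training loop below:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(1000, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 10),
)

print(model[0].weight.shape)   # torch.Size([100, 1000])
print(model[0].bias.shape)     # torch.Size([100])
# Total number of learnable parameters: 100*1000 + 100 + 100*10 + 10 = 101110
print(sum(p.numel() for p in model.parameters()))

The full example using nn: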

import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model as a sequence of layers. nn.Sequential
# is a Module which contains other Modules, and applies them in sequence to
# produce its output. Each Linear Module computes output from input using a
# linear function, and holds internal Tensors for its weight and bias.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)

# The nn package also contains definitions of popular loss functions; in this
# case we will use Mean Squared Error (MSE) as our loss function.
loss_fn = torch.nn.MSELoss(reduction='sum')

learning_rate = 1e-4
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model. Module objects
    # override the __call__ operator so you can call them like functions. When
    # doing so you pass a Tensor of input data to the Module and it produces
    # a Tensor of output data.
    y_pred = model(x)

    # Compute and print loss. We pass Tensors containing the predicted and true
    # values of y, and the loss function returns a Tensor containing the
    # loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Zero the gradients before running the backward pass.
    model.zero_grad()

    # Backward pass: compute gradient of the loss with respect to all the learnable
    # parameters of the model. Internally, the parameters of each Module are stored
    # in Tensors with requires_grad=True, so this call will compute gradients for
    # all learnable parameters in the model.
    loss.backward()

    # Update the weights using gradient descent. Each parameter is a Tensor, so
    # we can access its gradients like we did before.
    with torch.no_grad():
        for param in model.parameters():
            param -= learning_rate * param.grad

Output:

0 616.8349609375
1 570.4186401367188
2 530.6421508789062
3 495.7164001464844
4 464.91497802734375
5 437.1092834472656
6 411.87066650390625
7 388.5781555175781
8 367.105224609375
9 347.06768798828125
10 328.3486328125
11 310.6429748535156
12 294.08880615234375
13 278.54046630859375
14 263.8558044433594
15 249.83802795410156
16 236.52313232421875
17 223.8170166015625
18 211.7015380859375
19 200.13755798339844
20 189.1465301513672
21 178.6802520751953
22 168.74122619628906
23 159.31674194335938
24 150.35125732421875
25 141.79025268554688
26 133.63401794433594
27 125.89380645751953
28 118.53340148925781
29 111.54275512695312
30 104.91582489013672
31 98.65790557861328
32 92.7421646118164
33 87.18020629882812
34 81.94192504882812
35 77.01036834716797
36 72.3639144897461
37 67.99095916748047
38 63.88977813720703
39 60.036468505859375
40 56.426231384277344
41 53.05012512207031
42 49.88925552368164
43 46.92338943481445
44 44.14652633666992
45 41.54481887817383
46 39.10710144042969
47 36.83108139038086
48 34.69960403442383
49 32.706153869628906
50 30.837730407714844
51 29.088411331176758
52 27.445863723754883
53 25.909109115600586
54 24.46878433227539
55 23.117572784423828
56 21.849687576293945
57 20.658483505249023
58 19.538114547729492
59 18.48484992980957
60 17.494647979736328
61 16.563005447387695
62 15.686758995056152
63 14.856837272644043
64 14.07450008392334
65 13.338075637817383
66 12.644659042358398
67 11.991114616394043
68 11.375090599060059
69 10.79391098022461
70 10.2453031539917
71 9.728269577026367
72 9.24122142791748
73 8.78144359588623
74 8.347025871276855
75 7.936522960662842
76 7.547698497772217
77 7.180155277252197
78 6.832812786102295
79 6.503701210021973
80 6.192763328552246
81 5.897329330444336
82 5.617427349090576
83 5.3520426750183105
84 5.100094795227051
85 4.860849380493164
86 4.634103298187256
87 4.419088363647461
88 4.2145843505859375
89 4.020376682281494
90 3.8357491493225098
91 3.660381317138672
92 3.493623733520508
93 3.3348333835601807
94 3.1840157508850098
95 3.040458917617798
96 2.90377140045166
97 2.7736332416534424
98 2.6498191356658936
99 2.531902313232422
100 2.419586658477783
101 2.312699317932129
102 2.2109007835388184
103 2.113172769546509
104 2.0200531482696533
105 1.931264042854309
106 1.8465205430984497
107 1.765866994857788
108 1.6889408826828003
109 1.6155205965042114
110 1.545505404472351
111 1.478671669960022
112 1.4148930311203003
113 1.354078769683838
114 1.2960002422332764
115 1.2405637502670288
116 1.1877013444900513
117 1.1373679637908936
118 1.0893242359161377
119 1.0433975458145142
120 0.9995244741439819
121 0.957534909248352
122 0.9174259901046753
123 0.8791078329086304
124 0.8424510359764099
125 0.8073628544807434
126 0.7738264203071594
127 0.7417508959770203
128 0.7110787630081177
129 0.6817481517791748
130 0.6536678075790405
131 0.6267886161804199
132 0.6010832786560059
133 0.5764719843864441
134 0.55291348695755
135 0.5303621292114258
136 0.508778989315033
137 0.4881012439727783
138 0.46830493211746216
139 0.4493491053581238
140 0.4311933219432831
141 0.41380244493484497
142 0.3971446752548218
143 0.3811851739883423
144 0.3658958375453949
145 0.35125595331192017
146 0.33723005652427673
147 0.3237687349319458
148 0.3108976483345032
149 0.2985706031322479
150 0.28673914074897766
151 0.27540862560272217
152 0.26455509662628174
153 0.2541574537754059
154 0.24418404698371887
155 0.23461730778217316
156 0.2254522293806076
157 0.2166563868522644
158 0.2082248032093048
159 0.20014120638370514
160 0.1923948973417282
161 0.18496574461460114
162 0.17784026265144348
163 0.17100077867507935
164 0.16445767879486084
165 0.15817493200302124
166 0.1521458625793457
167 0.14637260138988495
168 0.14082613587379456
169 0.13550066947937012
170 0.1303861439228058
171 0.12548044323921204
172 0.12077119201421738
173 0.11624855548143387
174 0.11190006136894226
175 0.10772517323493958
176 0.10371194034814835
177 0.0998566597700119
178 0.09615585952997208
179 0.09260128438472748
180 0.08918008953332901
181 0.08589056134223938
182 0.08273126929998398
183 0.07969754189252853
184 0.076775923371315
185 0.07396624237298965
186 0.07126504927873611
187 0.0686689168214798
188 0.06617266684770584
189 0.06377072632312775
190 0.06145996227860451
191 0.05924048647284508
192 0.057103026658296585
193 0.055055245757102966
194 0.05308816209435463
195 0.05119483917951584
196 0.04937274008989334
197 0.04762033745646477
198 0.04593187943100929
199 0.044306349009275436
200 0.04274118319153786
201 0.04123516380786896
202 0.03978639841079712
203 0.038389310240745544
204 0.03704371303319931
205 0.0357489138841629
206 0.03450245037674904
207 0.03330256789922714
208 0.0321439690887928
209 0.03102947771549225
210 0.02995586395263672
211 0.02892051823437214
212 0.02792329527437687
213 0.026961468160152435
214 0.02603522501885891
215 0.02514214999973774
216 0.02428179606795311
217 0.02345244586467743
218 0.02265304885804653
219 0.021882878616452217
220 0.02113926038146019
221 0.02042245678603649
222 0.019731855019927025
223 0.01906573213636875
224 0.018423302099108696
225 0.017803329974412918
226 0.017205584794282913
227 0.016628772020339966
228 0.016072314232587814
229 0.015535356476902962
230 0.015017733909189701
231 0.014518306590616703
232 0.014036747626960278
233 0.013571222312748432
234 0.013121938332915306
235 0.012688718736171722
236 0.012270272709429264
237 0.011866560205817223
238 0.011476823128759861
239 0.011100317351520061
240 0.010736910626292229
241 0.010386170819401741
242 0.0100497892126441
243 0.0097251171246171
244 0.009411685168743134
245 0.009108913131058216
246 0.008816596120595932
247 0.00853422749787569
248 0.008261235430836678
249 0.007997557520866394
250 0.007742919027805328
251 0.007496801670640707
252 0.007259095553308725
253 0.007029606495052576
254 0.006807629019021988
255 0.006593053694814444
256 0.006385627202689648
257 0.006185163743793964
258 0.005991355516016483
259 0.005803895648568869
260 0.005622652359306812
261 0.0054474519565701485
262 0.005278183612972498
263 0.005114411935210228
264 0.004955946933478117
265 0.004802867770195007
266 0.004654645454138517
267 0.004511530976742506
268 0.004373411647975445
269 0.00423995777964592
270 0.004110764712095261
271 0.003985690884292126
272 0.003864638740196824
273 0.0037474501878023148
274 0.0036340728402137756
275 0.0035244303289800882
276 0.0034184216056019068
277 0.0033156615681946278
278 0.0032162207644432783
279 0.003119939938187599
280 0.0030267902184277773
281 0.0029364691581577063
282 0.002849065698683262
283 0.00276445085182786
284 0.0026824434753507376
285 0.0026030712760984898
286 0.002526269294321537
287 0.002451810520142317
288 0.0023796753957867622
289 0.0023097810335457325
290 0.002242131158709526
291 0.002176533453166485
292 0.002112946705892682
293 0.002051342511549592
294 0.0019916510209441185
295 0.0019338340498507023
296 0.0018778254743665457
297 0.0018235259922221303
298 0.0017708562081679702
299 0.001719821011647582
300 0.001670365920290351
301 0.0016224351711571217
302 0.001575930742546916
303 0.0015308443689718843
304 0.0014871100429445505
305 0.0014447338180616498
306 0.0014036260545253754
307 0.0013637744123116136
308 0.0013251130003482103
309 0.00128761469386518
310 0.00125124619808048
311 0.0012159458128735423
312 0.0011817403137683868
313 0.0011485485592857003
314 0.0011163217714056373
315 0.0010850488906726241
316 0.0010547296842560172
317 0.0010253143263980746
318 0.0009967893129214644
319 0.0009690661099739373
320 0.0009421803988516331
321 0.0009160566260106862
322 0.0008907159208320081
323 0.0008661439060233533
324 0.0008422775426879525
325 0.0008190966327674687
326 0.0007965879631228745
327 0.0007747435010969639
328 0.000753526808694005
329 0.0007329185027629137
330 0.000712935405317694
331 0.0006935214041732252
332 0.0006746658473275602
333 0.0006563455681316555
334 0.0006385522428900003
335 0.000621272309217602
336 0.0006044972687959671
337 0.0005881927208974957
338 0.00057236134307459
339 0.0005569689092226326
340 0.0005420160596258938
341 0.000527493713889271
342 0.000513377774041146
343 0.0004996666684746742
344 0.0004863426147494465
345 0.00047339097363874316
346 0.0004608244926203042
347 0.0004486078687477857
348 0.0004367205547168851
349 0.0004251683712936938
350 0.0004139319353271276
351 0.0004030133131891489
352 0.00039240249316208065
353 0.00038208605838008225
354 0.00037205807166174054
355 0.0003623157390393317
356 0.0003528340021148324
357 0.00034362293081358075
358 0.0003346629673615098
359 0.0003259552177041769
360 0.0003174722078256309
361 0.00030922345467843115
362 0.0003012036031577736
363 0.00029340494074858725
364 0.00028583104722201824
365 0.000278460793197155
366 0.00027128090732730925
367 0.00026430474827066064
368 0.0002575131948105991
369 0.0002509095938876271
370 0.00024448230396956205
371 0.00023822381626814604
372 0.0002321432693861425
373 0.00022622810502070934
374 0.0002204657648690045
375 0.0002148632047465071
376 0.00020940121612511575
377 0.00020409838180057704
378 0.0001989272132050246
379 0.00019389843509998173
380 0.00018900231225416064
381 0.0001842392230173573
382 0.00017960301192943007
383 0.00017508988094050437
384 0.00017069902969524264
385 0.00016641429101582617
386 0.00016225305444095284
387 0.000158198265125975
388 0.00015424926823470742
389 0.000150404914165847
390 0.0001466635148972273
391 0.00014301779447123408
392 0.00013946628314442933
393 0.00013601550017483532
394 0.00013265143206808716
395 0.00012937198334839195
396 0.00012617645552381873
397 0.0001230676716659218
398 0.0001200390252051875
399 0.00011708753299899399
400 0.00011421682575019076
401 0.00011141804861836135
402 0.00010869121615542099
403 0.00010603333066683263
404 0.00010344652400817722
405 0.00010092253796756268
406 9.846295870374888e-05
407 9.6071045845747e-05
408 9.373520879307762e-05
409 9.146131196757779e-05
410 8.924482244765386e-05
411 8.70852600201033e-05
412 8.498283568769693e-05
413 8.293260907521471e-05
414 8.093572250800207e-05
415 7.89878613431938e-05
416 7.708741031819955e-05
417 7.523671229137108e-05
418 7.34319764887914e-05
419 7.167273724917322e-05
420 6.995951844146475e-05
421 6.828828190919012e-05
422 6.665829278063029e-05
423 6.507106445496902e-05
424 6.35203905403614e-05
425 6.200941425049677e-05
426 6.05389905103948e-05
427 5.9100995713379234e-05
428 5.7702251069713384e-05
429 5.633640830637887e-05
430 5.500561746885069e-05
431 5.3707120969193056e-05
432 5.2441897423705086e-05
433 5.1206407079007477e-05
434 5.000238525099121e-05
435 4.882620123680681e-05
436 4.76813547720667e-05
437 4.656398596125655e-05
438 4.5474393118638545e-05
439 4.4412216084310785e-05
440 4.3375530367484316e-05
441 4.236115637468174e-05
442 4.137431824347004e-05
443 4.0410945075564086e-05
444 3.94726412196178e-05
445 3.8555674109375104e-05
446 3.765980727621354e-05
447 3.6788511351915076e-05
448 3.593677683966234e-05
449 3.510593160171993e-05
450 3.4294465876882896e-05
451 3.350316546857357e-05
452 3.273158290539868e-05
453 3.197851401637308e-05
454 3.124314753222279e-05
455 3.0524690373567864e-05
456 2.982511796290055e-05
457 2.9140464903321117e-05
458 2.8473716156440787e-05
459 2.7822752599604428e-05
460 2.718799078138545e-05
461 2.656610740814358e-05
462 2.5960111088352278e-05
463 2.536940883146599e-05
464 2.479194699844811e-05
465 2.4228789698099717e-05
466 2.367901470279321e-05
467 2.314150333404541e-05
468 2.2617272406932898e-05
469 2.2104968593339436e-05
470 2.1604977519018576e-05
471 2.111716639774386e-05
472 2.0640816728700884e-05
473 2.0175557438051328e-05
474 1.9721246644621715e-05
475 1.92775023606373e-05
476 1.8843960788217373e-05
477 1.8419739717501216e-05
478 1.8007833205047064e-05
479 1.760409759299364e-05
480 1.7209487850777805e-05
481 1.6824411432025954e-05
482 1.6448399037471972e-05
483 1.608097954886034e-05
484 1.5722764146630652e-05
485 1.537234493298456e-05
486 1.5029500900709536e-05
487 1.4695562640554272e-05
488 1.4368548363563605e-05
489 1.4049977835384198e-05
490 1.3737889275944326e-05
491 1.3433097592496779e-05
492 1.3135630069882609e-05
493 1.2845604032918345e-05
494 1.256068117072573e-05
495 1.2283581781957764e-05
496 1.2012886145384982e-05
497 1.1748711585823912e-05
498 1.1490279575809836e-05
499 1.1236619684495963e-05

PyTorch: optim

This time, instead of updating the model's weights by hand, we use the optim package to update the parameters for us. The optim package implements many common optimization algorithms, including SGD+momentum, RMSprop, Adam, and more.
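
Switching algorithms only requires changing how the optimizer is constructed, for example (a sketch with illustrative hyperparameters, not tuned for this problem):

import torch

model = torch.nn.Linear(10, 1)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)   # SGD + momentum
optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4)             # RMSprop
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)                # Adam (used below)

The full example using optim and Adam: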

import torch

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    loss = loss_fn(y_pred, y)
    print(t, loss.item())

    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out the docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()

Output:

0 791.6784057617188
1 772.9479370117188
2 754.7570190429688
3 737.1196899414062
4 719.9917602539062
5 703.4118041992188
6 687.2720947265625
7 671.5335083007812
8 656.197021484375
9 641.322265625
10 626.919189453125
11 612.9500732421875
12 599.3975219726562
13 586.327392578125
14 573.5608520507812
15 561.1778564453125
16 549.1342163085938
17 537.3661499023438
18 525.8930053710938
19 514.7115478515625
20 503.76434326171875
21 493.110107421875
22 482.747802734375
23 472.5677490234375
24 462.6549072265625
25 452.9748840332031
26 443.5308532714844
27 434.2815856933594
28 425.22064208984375
29 416.3988342285156
30 407.7718505859375
31 399.32879638671875
32 391.1062927246094
33 383.07781982421875
34 375.2408752441406
35 367.54071044921875
36 359.98126220703125
37 352.5944519042969
38 345.3690185546875
39 338.3002624511719
40 331.3758544921875
41 324.56463623046875
42 317.8782958984375
43 311.31622314453125
44 304.8724670410156
45 298.5438537597656
46 292.3463439941406
47 286.2505187988281
48 280.26092529296875
49 274.3931884765625
50 268.6314697265625
51 262.9659423828125
52 257.3974914550781
53 251.95529174804688
54 246.6234893798828
55 241.38490295410156
56 236.2388153076172
57 231.19229125976562
58 226.2274627685547
59 221.34634399414062
60 216.5567169189453
61 211.8523406982422
62 207.2224884033203
63 202.67787170410156
64 198.20230102539062
65 193.80657958984375
66 189.49249267578125
67 185.26438903808594
68 181.1260223388672
69 177.0719451904297
70 173.08523559570312
71 169.17050170898438
72 165.3196258544922
73 161.5393524169922
74 157.82603454589844
75 154.1715545654297
76 150.57814025878906
77 147.04661560058594
78 143.57847595214844
79 140.1695098876953
80 136.83058166503906
81 133.5507049560547
82 130.3309326171875
83 127.16106414794922
84 124.03736877441406
85 120.97663116455078
86 117.97151947021484
87 115.01973724365234
88 112.12226104736328
89 109.28099822998047
90 106.49259948730469
91 103.7560806274414
92 101.0788345336914
93 98.46456146240234
94 95.8936538696289
95 93.37211608886719
96 90.9074478149414
97 88.49383544921875
98 86.1294174194336
99 83.81201171875
100 81.54666137695312
101 79.3274917602539
102 77.15277862548828
103 75.02595520019531
104 72.94225311279297
105 70.90507507324219
106 68.91564178466797
107 66.96821594238281
108 65.06224822998047
109 63.201744079589844
110 61.382633209228516
111 59.60671615600586
112 57.873958587646484
113 56.179588317871094
114 54.523807525634766
115 52.906803131103516
116 51.32786178588867
117 49.788665771484375
118 48.285526275634766
119 46.82194137573242
120 45.393489837646484
121 44.00169372558594
122 42.64366149902344
123 41.320953369140625
124 40.03276824951172
125 38.77764892578125
126 37.554481506347656
127 36.362998962402344
128 35.20222854614258
129 34.07368850708008
130 32.9746208190918
131 31.906068801879883
132 30.866662979125977
133 29.854427337646484
134 28.871971130371094
135 27.915695190429688
136 26.98732566833496
137 26.086950302124023
138 25.21206283569336
139 24.362747192382812
140 23.53691291809082
141 22.73594093322754
142 21.958166122436523
143 21.20453453063965
144 20.472272872924805
145 19.762775421142578
146 19.074495315551758
147 18.407861709594727
148 17.762025833129883
149 17.135404586791992
150 16.52802085876465
151 15.940832138061523
152 15.3715238571167
153 14.819578170776367
154 14.285402297973633
155 13.767573356628418
156 13.266966819763184
157 12.782539367675781
158 12.314498901367188
159 11.861427307128906
160 11.422765731811523
161 10.999663352966309
162 10.590702056884766
163 10.195518493652344
164 9.813480377197266
165 9.444453239440918
166 9.088181495666504
167 8.743974685668945
168 8.411545753479004
169 8.090726852416992
170 7.781579494476318
171 7.482759475708008
172 7.194526672363281
173 6.916567325592041
174 6.648494243621826
175 6.389762878417969
176 6.139918327331543
177 5.899596214294434
178 5.667441368103027
179 5.443691253662109
180 5.22829008102417
181 5.020461559295654
182 4.8205108642578125
183 4.627894401550293
184 4.442294120788574
185 4.263693332672119
186 4.091548919677734
187 3.9257419109344482
188 3.766235828399658
189 3.6124517917633057
190 3.4642839431762695
191 3.3219587802886963
192 3.18510103225708
193 3.0533838272094727
194 2.926755905151367
195 2.8052432537078857
196 2.688368082046509
197 2.576014757156372
198 2.4679884910583496
199 2.3642969131469727
200 2.264831066131592
201 2.16910457611084
202 2.077399253845215
203 1.9893807172775269
204 1.9050089120864868
205 1.8239641189575195
206 1.746296763420105
207 1.6716521978378296
208 1.6001081466674805
209 1.531592607498169
210 1.4656728506088257
211 1.4025815725326538
212 1.34219491481781
213 1.2842515707015991
214 1.2286741733551025
215 1.175434947013855
216 1.1243504285812378
217 1.0754128694534302
218 1.028563141822815
219 0.9836430549621582
220 0.9406015276908875
221 0.899389386177063
222 0.8599286079406738
223 0.8221437335014343
224 0.7859696745872498
225 0.7512990832328796
226 0.7181407809257507
227 0.686427891254425
228 0.6560250520706177
229 0.6269176602363586
230 0.5990850329399109
231 0.5724562406539917
232 0.5469314455986023
233 0.5225721001625061
234 0.4992505609989166
235 0.4769216477870941
236 0.4555879533290863
237 0.4351666569709778
238 0.41567420959472656
239 0.3969923257827759
240 0.3791610300540924
241 0.3620989918708801
242 0.3457787036895752
243 0.3301834166049957
244 0.3152735233306885
245 0.30102694034576416
246 0.2874011993408203
247 0.2743793725967407
248 0.26194536685943604
249 0.2500665485858917
250 0.23870456218719482
251 0.2278587371110916
252 0.21748881042003632
253 0.2075870931148529
254 0.19812825322151184
255 0.18910430371761322
256 0.18047408759593964
257 0.1722312867641449
258 0.1643613874912262
259 0.15685199201107025
260 0.1496802419424057
261 0.14283005893230438
262 0.13628333806991577
263 0.13004115223884583
264 0.12407437711954117
265 0.11838021129369736
266 0.11294573545455933
267 0.10776441544294357
268 0.10281281918287277
269 0.09808588773012161
270 0.09357728809118271
271 0.08927135169506073
272 0.08516616374254227
273 0.08124646544456482
274 0.07750718295574188
275 0.07393348217010498
276 0.07052432745695114
277 0.06727367639541626
278 0.06417549401521683
279 0.061211828142404556
280 0.05838524177670479
281 0.055691372603178024
282 0.05311958119273186
283 0.050665080547332764
284 0.04832402989268303
285 0.046091243624687195
286 0.04395892843604088
287 0.04192692041397095
288 0.03998580947518349
289 0.0381360799074173
290 0.03637010604143143
291 0.034684911370277405
292 0.0330788791179657
293 0.03154633194208145
294 0.03008444234728813
295 0.028689226135611534
296 0.027359401807188988
297 0.02609046921133995
298 0.02487858757376671
299 0.023724785074591637
300 0.02262282185256481
301 0.0215724129229784
302 0.020570827648043633
303 0.01961461454629898
304 0.018704809248447418
305 0.017835130915045738
306 0.01700596511363983
307 0.016215775161981583
308 0.015460998751223087
309 0.014741620048880577
310 0.014055773615837097
311 0.013400538824498653
312 0.01277637854218483
313 0.012180116027593613
314 0.011613084003329277
315 0.011071483604609966
316 0.010555664077401161
317 0.010063034482300282
318 0.009593896567821503
319 0.009146292693912983
320 0.00871907826513052
321 0.008312968537211418
322 0.0079236114397645
323 0.007553379982709885
324 0.007200181949883699
325 0.006863656919449568
326 0.00654265284538269
327 0.006236561574041843
328 0.005944761913269758
329 0.005666198208928108
330 0.005400847643613815
331 0.0051480308175086975
332 0.0049063521437346935
333 0.004676346201449633
334 0.004457048140466213
335 0.00424786563962698
336 0.00404843594878912
337 0.003858234267681837
338 0.0036770079750567675
339 0.003504228312522173
340 0.003339442191645503
341 0.0031823033932596445
342 0.003032673615962267
343 0.00288983853533864
344 0.0027537832502275705
345 0.0026240183506160975
346 0.0025004481431096792
347 0.002382527105510235
348 0.0022701220586895943
349 0.002162923803552985
350 0.002060842467471957
351 0.001963510410860181
352 0.0018707378767430782
353 0.0017822845838963985
354 0.001697997679002583
355 0.0016178140649572015
356 0.0015412119682878256
357 0.0014681273605674505
358 0.0013986037811264396
359 0.0013322994345799088
360 0.0012691081501543522
361 0.0012088974472135305
362 0.001151501783169806
363 0.001096819993108511
364 0.0010446718661114573
365 0.0009950095554813743
366 0.0009476402192376554
367 0.000902548257727176
368 0.0008595681865699589
369 0.0008186014601960778
370 0.0007795642595738173
371 0.0007423567585647106
372 0.0007069373968988657
373 0.0006732249748893082
374 0.0006410047644749284
375 0.0006103675113990903
376 0.0005811756127513945
377 0.0005533915827982128
378 0.0005268672248348594
379 0.0005016201175749302
380 0.00047757592983543873
381 0.0004546679265331477
382 0.0004328518407419324
383 0.0004121344827581197
384 0.0003922642790712416
385 0.00037340185372158885
386 0.000355449941707775
387 0.0003383288567420095
388 0.00032203507726080716
389 0.000306513044051826
390 0.0002917253877967596
391 0.0002776390756480396
392 0.0002642270701471716
393 0.0002514584339223802
394 0.00023929811140988022
395 0.0002277194580528885
396 0.0002166864142054692
397 0.00020619173301383853
398 0.0001961870730156079
399 0.0001866663369582966
400 0.0001776068238541484
401 0.00016897440946195275
402 0.00016075365419965237
403 0.00015293085016310215
404 0.0001454944722354412
405 0.00013839844905305654
406 0.0001316592242801562
407 0.00012522179167717695
408 0.00011911398178199306
409 0.00011329452536301687
410 0.0001077575288945809
411 0.00010248417675029486
412 9.74709982983768e-05
413 9.269234578823671e-05
414 8.815102046355605e-05
415 8.382758096558973e-05
416 7.971279410412535e-05
417 7.579699740745127e-05
418 7.207059388747439e-05
419 6.852364458609372e-05
420 6.515011045848951e-05
421 6.194211891852319e-05
422 5.8886504120891914e-05
423 5.598169445875101e-05
424 5.322420838638209e-05
425 5.058886745246127e-05
426 4.808947414858267e-05
427 4.571105819195509e-05
428 4.344765329733491e-05
429 4.1297284042229876e-05
430 3.9249673136509955e-05
431 3.7301993870642036e-05
432 3.545052823028527e-05
433 3.3691019780235365e-05
434 3.2015552278608084e-05
435 3.0420967959798872e-05
436 2.890893301810138e-05
437 2.74712383543374e-05
438 2.6098401576746255e-05
439 2.4797802325338125e-05
440 2.3561869966215454e-05
441 2.2385444026440382e-05
442 2.1267582269501872e-05
443 2.0203089661663398e-05
444 1.9192844774806872e-05
445 1.8230841305921786e-05
446 1.7318037862423807e-05
447 1.6449723261757754e-05
448 1.562457691761665e-05
449 1.4840144103800412e-05
450 1.409533797414042e-05
451 1.3387114449869841e-05
452 1.2712825991911814e-05
453 1.2073536709067412e-05
454 1.1465210263850167e-05
455 1.0887116332014557e-05
456 1.0337735147913918e-05
457 9.81609719019616e-06
458 9.320682693214621e-06
459 8.849255209497642e-06
460 8.402344064961653e-06
461 7.977385394042358e-06
462 7.57272437112988e-06
463 7.188868949015159e-06
464 6.824670890637208e-06
465 6.479081548604881e-06
466 6.150191438791808e-06
467 5.837410753883887e-06
468 5.5405430430255365e-06
469 5.259076715447009e-06
470 4.991058631276246e-06
471 4.736447863251669e-06
472 4.495564553508302e-06
473 4.266698852006812e-06
474 4.049014023621567e-06
475 3.842098976747366e-06
476 3.6463525248109363e-06
477 3.459550043771742e-06
478 3.2829811971168965e-06
479 3.114487753919093e-06
480 2.955280706373742e-06
481 2.804437144732219e-06
482 2.6597372198011726e-06
483 2.523453986214008e-06
484 2.3942884581629187e-06
485 2.271245421070489e-06
486 2.1546265998040326e-06
487 2.0439254058146616e-06
488 1.938894001796143e-06
489 1.8390732066109194e-06
490 1.7443379647374968e-06
491 1.654603806855448e-06
492 1.569268874845875e-06
493 1.4882393770676572e-06
494 1.4111951713857707e-06
495 1.3382320958044147e-06
496 1.269111749024887e-06
497 1.2034716974085313e-06
498 1.1412265621402184e-06
499 1.0824967375810957e-06

PyTorch: Custom nn Modules

We can also define the model as a class that inherits from nn.Module. Whenever you need a model that is more complex than a plain Sequential stack, defining it as an nn.Module subclass is the way to go.
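
As a hypothetical illustration (not from the original post) of why subclassing matters, a model with a residual skip connection needs custom arithmetic inside forward() that a plain Sequential cannot express. The dimensions below are arbitrary.

import torch

# Hypothetical sketch: a skip connection requires arithmetic inside forward(),
# which a Sequential container cannot express on its own.
class SkipConnectionNet(torch.nn.Module):
    def __init__(self, D, H):
        super(SkipConnectionNet, self).__init__()
        self.linear1 = torch.nn.Linear(D, H)
        self.linear2 = torch.nn.Linear(H, D)

    def forward(self, x):
        h = torch.relu(self.linear1(x))
        # Add the input back onto the transformed features (the "skip").
        return self.linear2(h) + x

With that motivation in mind, here is the original two-layer network written as an nn.Module subclass.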

import torch


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor gathers the learnable parameters of the two
# nn.Linear modules which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for t in range(500):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Output:

0 656.6958618164062
1 608.1090087890625
2 566.172607421875
3 529.2335815429688
4 496.7382507324219
5 467.453125
6 440.5755310058594
7 416.12872314453125
8 393.6068420410156
9 372.708251953125
10 353.00006103515625
11 334.477783203125
12 316.97283935546875
13 300.36737060546875
14 284.6544189453125
15 269.65936279296875
16 255.33456420898438
17 241.66688537597656
18 228.60800170898438
19 216.09536743164062
20 204.13780212402344
21 192.75645446777344
22 181.89234924316406
23 171.58370971679688
24 161.7939453125
25 152.4780731201172
26 143.59371948242188
27 135.14727783203125
28 127.13992309570312
29 119.55585479736328
30 112.37797546386719
31 105.62073516845703
32 99.24383544921875
33 93.24134826660156
34 87.58341979980469
35 82.25212860107422
36 77.24210357666016
37 72.55087280273438
38 68.1427230834961
39 64.00277709960938
40 60.1308479309082
41 56.49887466430664
42 53.0952033996582
43 49.906524658203125
44 46.91959762573242
45 44.11970520019531
46 41.50297164916992
47 39.062870025634766
48 36.77130126953125
49 34.62575149536133
50 32.61510467529297
51 30.732582092285156
52 28.968435287475586
53 27.313982009887695
54 25.76266098022461
55 24.30921173095703
56 22.946903228759766
57 21.669231414794922
58 20.470420837402344
59 19.347976684570312
60 18.293415069580078
61 17.30330467224121
62 16.373027801513672
63 15.501211166381836
64 14.681639671325684
65 13.909913063049316
66 13.183857917785645
67 12.499263763427734
68 11.854406356811523
69 11.247345924377441
70 10.675314903259277
71 10.13557243347168
72 9.626264572143555
73 9.145630836486816
74 8.692134857177734
75 8.26422119140625
76 7.859929084777832
77 7.478569030761719
78 7.118776321411133
79 6.778269290924072
80 6.455151557922363
81 6.149427890777588
82 5.859961986541748
83 5.586189270019531
84 5.326669692993164
85 5.081073760986328
86 4.848598003387451
87 4.628039836883545
88 4.418812274932861
89 4.22011137008667
90 4.03148889541626
91 3.85233211517334
92 3.6820061206817627
93 3.5200538635253906
94 3.366295576095581
95 3.220104694366455
96 3.081094980239868
97 2.9478847980499268
98 2.8211557865142822
99 2.700608491897583
100 2.585635185241699
101 2.476266860961914
102 2.3720734119415283
103 2.27278733253479
104 2.178467273712158
105 2.0886454582214355
106 2.0029616355895996
107 1.9212790727615356
108 1.8433643579483032
109 1.7690125703811646
110 1.6980165243148804
111 1.630236029624939
112 1.5654228925704956
113 1.5034924745559692
114 1.444340467453003
115 1.387866735458374
116 1.3338878154754639
117 1.2822437286376953
118 1.232841968536377
119 1.1855518817901611
120 1.1402850151062012
121 1.0969470739364624
122 1.0554102659225464
123 1.0156100988388062
124 0.9774888753890991
125 0.9409765601158142
126 0.9060029983520508
127 0.8724609017372131
128 0.8402785062789917
129 0.8092858195304871
130 0.7795636653900146
131 0.751034140586853
132 0.7236625552177429
133 0.697391927242279
134 0.6721824407577515
135 0.6479606628417969
136 0.6246941685676575
137 0.6023442149162292
138 0.5808628797531128
139 0.5602169036865234
140 0.5403794050216675
141 0.5213167071342468
142 0.5029921531677246
143 0.48536738753318787
144 0.4684107005596161
145 0.4521111249923706
146 0.4364219009876251
147 0.4213317036628723
148 0.4068053662776947
149 0.3928261399269104
150 0.3793690800666809
151 0.3664127290248871
152 0.35393786430358887
153 0.34193798899650574
154 0.33036914467811584
155 0.3192187547683716
156 0.3084754943847656
157 0.2981330454349518
158 0.28816646337509155
159 0.2785727083683014
160 0.26933375000953674
161 0.2604302763938904
162 0.2518427073955536
163 0.2435620129108429
164 0.2355814278125763
165 0.22787995636463165
166 0.22044798731803894
167 0.2132810354232788
168 0.20636311173439026
169 0.19969020783901215
170 0.19325482845306396
171 0.18703718483448029
172 0.18104007840156555
173 0.17524453997612
174 0.16964828968048096
175 0.16424523293972015
176 0.1590285748243332
177 0.15399070084095
178 0.14912430942058563
179 0.14442437887191772
180 0.13988198339939117
181 0.13549421727657318
182 0.13124999403953552
183 0.12715040147304535
184 0.12318697571754456
185 0.11935491114854813
186 0.11565019190311432
187 0.11206966638565063
188 0.10860667377710342
189 0.10525806993246078
190 0.10202305018901825
191 0.09889409691095352
192 0.09586735814809799
193 0.09293802082538605
194 0.09010337293148041
195 0.0873604416847229
196 0.0847071036696434
197 0.08213980495929718
198 0.07965682446956635
199 0.07725474238395691
200 0.07492762058973312
201 0.07267512381076813
202 0.07049493491649628
203 0.06838559359312057
204 0.06634149700403214
205 0.06436219811439514
206 0.062446292489767075
207 0.06059153750538826
208 0.0587935708463192
209 0.05705287307500839
210 0.05536716431379318
211 0.05373508855700493
212 0.052152544260025024
213 0.05062079429626465
214 0.049136947840452194
215 0.04769892245531082
216 0.0463058203458786
217 0.04495447129011154
218 0.043645307421684265
219 0.04237687215209007
220 0.04114643111824989
221 0.03995450586080551
222 0.03880004957318306
223 0.03768027573823929
224 0.03659489378333092
225 0.03554360195994377
226 0.03452374413609505
227 0.03353426977992058
228 0.032575368881225586
229 0.03164541721343994
230 0.03074309229850769
231 0.02986789681017399
232 0.029019085690379143
233 0.028195498511195183
234 0.0273965485394001
235 0.02662237547338009
236 0.025871429592370987
237 0.025142081081867218
238 0.024434441700577736
239 0.023747552186250687
240 0.023081056773662567
241 0.022434504702687263
242 0.021806789562106133
243 0.021197138354182243
244 0.020606014877557755
245 0.02003205008804798
246 0.019474638625979424
247 0.018934227526187897
248 0.01840922422707081
249 0.01789930835366249
250 0.017403993755578995
251 0.016923150047659874
252 0.016456523910164833
253 0.0160031970590353
254 0.015562880784273148
255 0.01513546984642744
256 0.014720354229211807
257 0.01431703194975853
258 0.013925755396485329
259 0.013545180670917034
260 0.013175630941987038
261 0.01281663216650486
262 0.012467730790376663
263 0.012128635309636593
264 0.011799173429608345
265 0.011479041539132595
266 0.011168107390403748
267 0.01086602546274662
268 0.01057237759232521
269 0.010287342593073845
270 0.010010016150772572
271 0.009740477427840233
272 0.009478610940277576
273 0.009224051609635353
274 0.008976546116173267
275 0.00873596128076315
276 0.008502164855599403
277 0.008274776861071587
278 0.00805367436259985
279 0.007838784717023373
280 0.007630025502294302
281 0.007427022326737642
282 0.0072296857833862305
283 0.007037700153887272
284 0.006851076614111662
285 0.00666955066844821
286 0.006492926273494959
287 0.00632116524502635
288 0.006154042202979326
289 0.005991632118821144
290 0.0058336709626019
291 0.0056800455786287785
292 0.005530708469450474
293 0.005385328084230423
294 0.0052438536658883095
295 0.005106212105602026
296 0.004972323775291443
297 0.004842082504183054
298 0.004715400282293558
299 0.004592181649059057
300 0.004472256172448397
301 0.004355618264526129
302 0.0042420197278261185
303 0.004131658002734184
304 0.004024124704301357
305 0.003919483628123999
306 0.003817657707259059
307 0.0037186346016824245
308 0.003622243879362941
309 0.003528367727994919
310 0.0034370282664895058
311 0.0033481812570244074
312 0.0032616364769637585
313 0.0031774123199284077
314 0.00309544475749135
315 0.003015762660652399
316 0.002938046120107174
317 0.002862409455701709
318 0.002788822166621685
319 0.0027171699330210686
320 0.0026473761536180973
321 0.002579432912170887
322 0.002513323212042451
323 0.0024489827919751406
324 0.0023862975649535656
325 0.0023252193350344896
326 0.0022658223751932383
327 0.0022080307826399803
328 0.0021516724955290556
329 0.002096814103424549
330 0.002043382963165641
331 0.0019913306459784508
332 0.0019406524952501059
333 0.0018913011299446225
334 0.0018432465149089694
335 0.0017964182188734412
336 0.0017508040182292461
337 0.001706396578811109
338 0.001663187169469893
339 0.0016211027977988124
340 0.0015800849068909883
341 0.0015401256969198585
342 0.0015011945506557822
343 0.001463254913687706
344 0.0014263042248785496
345 0.001390322926454246
346 0.0013552795862779021
347 0.0013211170444265008
348 0.0012878385605290532
349 0.0012554213171824813
350 0.0012238684576004744
351 0.0011931274784728885
352 0.0011631487868726254
353 0.0011339503107592463
354 0.0011054981732740998
355 0.0010777906281873584
356 0.0010507676051929593
357 0.0010244473814964294
358 0.0009987998055294156
359 0.00097382947569713
360 0.0009494827827438712
361 0.0009257435449399054
362 0.000902607396710664
363 0.0008801089134067297
364 0.0008581369183957577
365 0.0008367350674234331
366 0.0008158780983649194
367 0.0007955483160912991
368 0.0007757404237054288
369 0.0007564350962638855
370 0.000737626978661865
371 0.0007192868506535888
372 0.0007014275179244578
373 0.0006839964189566672
374 0.0006670261500403285
375 0.0006505012279376388
376 0.0006343786371871829
377 0.0006186614627949893
378 0.0006033276440575719
379 0.0005883832345716655
380 0.0005738206673413515
381 0.0005596213741227984
382 0.0005457888473756611
383 0.0005322962533682585
384 0.0005191444652155042
385 0.000506328884512186
386 0.0004938290221616626
387 0.00048163760220631957
388 0.0004697689728345722
389 0.0004582055553328246
390 0.00044691533548757434
391 0.00043590739369392395
392 0.0004251690406817943
393 0.0004147063591517508
394 0.00040450665983371437
395 0.0003945553908124566
396 0.0003848606429528445
397 0.00037539892946369946
398 0.0003661849768832326
399 0.00035720854066312313
400 0.000348439411027357
401 0.0003398970584385097
402 0.00033156739664264023
403 0.0003234421892557293
404 0.0003155224258080125
405 0.00030779733788222075
406 0.0003002593875862658
407 0.00029291390092112124
408 0.00028574312455020845
409 0.0002787590492516756
410 0.000271946337306872
411 0.00026530082686804235
412 0.00025882109184749424
413 0.0002525025047361851
414 0.00024633959401398897
415 0.00024033462977968156
416 0.00023447422427125275
417 0.00022875398281030357
418 0.00022317911498248577
419 0.00021774203923996538
420 0.00021243662922643125
421 0.00020726659568026662
422 0.0002022238913923502
423 0.00019730582425836474
424 0.00019250763580203056
425 0.0001878279581433162
426 0.00018326277495361865
427 0.00017881377425510436
428 0.0001744656328810379
429 0.00017023469263222069
430 0.0001660993875702843
431 0.0001620692346477881
432 0.00015813524078112096
433 0.00015429955965373665
434 0.0001505605468992144
435 0.00014691019896417856
436 0.00014335215382743627
437 0.00013987701095174998
438 0.0001364911295240745
439 0.00013318308629095554
440 0.00012996031728107482
441 0.00012681707448791713
442 0.0001237515825778246
443 0.00012076242273906246
444 0.0001178424499812536
445 0.0001149943345808424
446 0.00011221617023693398
447 0.00010950590512948111
448 0.00010686153109418228
449 0.00010428134555695578
450 0.00010176430077990517
451 9.930722444551066e-05
452 9.691028390079737e-05
453 9.457595297135413e-05
454 9.22925173654221e-05
455 9.006822801893577e-05
456 8.79026047186926e-05
457 8.57846753206104e-05
458 8.371831063413993e-05
459 8.170259388862178e-05
460 7.973532046889886e-05
461 7.781643944326788e-05
462 7.594288035761565e-05
463 7.411641854560003e-05
464 7.233369251480326e-05
465 7.059406925691292e-05
466 6.889844371471554e-05
467 6.724218837916851e-05
468 6.56265692668967e-05
469 6.405053864000365e-05
470 6.251001468626782e-05
471 6.100907557993196e-05
472 5.9547543060034513e-05
473 5.8119461755268276e-05
474 5.6722430599620566e-05
475 5.536193202715367e-05
476 5.403558679972775e-05
477 5.274035356706008e-05
478 5.1476214139256626e-05
479 5.0242502766195685e-05
480 4.903853914584033e-05
481 4.78645451948978e-05
482 4.6718425437575206e-05
483 4.559816079563461e-05
484 4.4506497943075374e-05
485 4.344211265561171e-05
486 4.240254929754883e-05
487 4.13884044974111e-05
488 4.039857958559878e-05
489 3.943321280530654e-05
490 3.848948108498007e-05
491 3.7569927371805534e-05
492 3.667309647426009e-05
493 3.5794582800008357e-05
494 3.4940730984089896e-05
495 3.410521094338037e-05
496 3.3291358704445884e-05
497 3.249421206419356e-05
498 3.171973003190942e-05
499 3.09631614072714e-05

Reposted from blog.csdn.net/qq_35283167/article/details/104639842