PyTorch Machine Learning: Backpropagation

1. Implementing a Linear Model by Hand

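First, use NumPy's random module to generate 500 noisy two-dimensional data points. Their distribution is shown in the figure below, and the code is: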

import numpy as np
import matplotlib.pyplot as plt

# Generate random data: y = 3x + 1.5 plus Gaussian noise
np.random.seed(42)
n_examples = 500
x_data = 2*np.random.randn(n_examples) + 1.5
y_data = 3*x_data + 2.5*np.random.randn(n_examples) + 1.5
# Plot the data
plt.plot(x_data, y_data, 'b.')
plt.xlabel('x')
plt.ylabel('y')
plt.grid()
plt.show()
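As a reference point for the loss surface plotted later, the (w, b) that minimizes the MSE can also be computed in closed form by ordinary least squares. A minimal sketch, assuming NumPy's `np.polyfit` with degree 1 is an acceptable stand-in for solving the normal equations; it regenerates the same data so it runs on its own:

```python
import numpy as np

# Regenerate the same data as above (same seed, same formulas)
np.random.seed(42)
n_examples = 500
x_data = 2 * np.random.randn(n_examples) + 1.5
y_data = 3 * x_data + 2.5 * np.random.randn(n_examples) + 1.5

# Degree-1 least-squares fit; polyfit returns coefficients highest power first
w_fit, b_fit = np.polyfit(x_data, y_data, 1)
print(f"w = {w_fit:.3f}, b = {b_fit:.3f}")
```

Because the data were generated as y = 3x + 1.5 plus noise, the fitted values should land near 3 and 1.5, which is where the bottom of the MSE surface is expected to be.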

[Figure: scatter plot of the generated data]

The main modules are as follows:

# Forward pass: linear model y_hat = w*x + b
def forward(x, w, b):
    return x*w + b

# Squared-error loss for a single sample
def loss(x, y, w, b):
    y_predict = forward(x, w, b)
    return (y_predict - y)**2

# Mean squared error over all data; w and b may be scalars or
# equally shaped NumPy arrays (e.g. a meshgrid), thanks to broadcasting
def mse(w, b):
    l_sum = 0
    for x_val, y_val in zip(x_data, y_data):
        l_sum += loss(x_val, y_val, w, b)
    mse_val = l_sum / len(x_data)
    print('MSE={0}'.format(mse_val))
    return mse_val
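The `mse()` above loops over all 500 samples for every call. An alternative sketch (the names `mse_moments` and `mse_loop` are hypothetical, not from the original) expands the square, mean((wx + b − y)²) = w²·mean(x²) + b² + mean(y²) + 2wb·mean(x) − 2w·mean(xy) − 2b·mean(y), so the sample moments are computed once and any grid of (w, b) can be evaluated without a per-sample loop:

```python
import numpy as np

# Same data as above
np.random.seed(42)
n_examples = 500
x_data = 2 * np.random.randn(n_examples) + 1.5
y_data = 3 * x_data + 2.5 * np.random.randn(n_examples) + 1.5

# Sample moments, computed once
mx, my = x_data.mean(), y_data.mean()
mxx, myy, mxy = (x_data**2).mean(), (y_data**2).mean(), (x_data * y_data).mean()

def mse_moments(w, b):
    """MSE of y_hat = w*x + b; w and b may be scalars or same-shaped arrays."""
    # mean((w*x + b - y)^2), expanded term by term
    return (w**2 * mxx + b**2 + myy
            + 2 * w * b * mx - 2 * w * mxy - 2 * b * my)

def mse_loop(w, b):
    """Reference: the same per-sample loop as mse() above, without printing."""
    total = 0.0
    for x, y in zip(x_data, y_data):
        total += (x * w + b - y) ** 2
    return total / len(x_data)

w_grid, b_grid = np.meshgrid(np.arange(-20, 20, 0.1), np.arange(-20, 20, 0.1))
zz = mse_moments(w_grid, b_grid)   # one value per grid point
print(zz.shape, mse_moments(3.0, 1.5), mse_loop(3.0, 1.5))
```

The moment form does a fixed amount of work per grid point regardless of how many samples there are; the loop version is kept only to check that both give the same numbers.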

Finally, plot the mean squared loss as a 3D surface over the weight and bias. The code and the resulting figure are shown below:

# Draw the surface in a single figure with 3D axes
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# Grid of weight and bias values
w_list = np.arange(-20, 20, 0.1)
b_list = np.arange(-20, 20, 0.1)
# Build the 2D mesh
xx, yy = np.meshgrid(w_list, b_list, sparse=False, indexing='xy')
# Loss (height) at each grid point
zz = mse(xx, yy)
ax.plot_surface(xx, yy, zz, rstride=1, cstride=1, cmap='rainbow')
# ax.contourf(xx, yy, zz, zdir='z', offset=0.99987, cmap='summer')
# Axis labels
ax.set_xlabel('weight')
ax.set_ylabel('bias')
ax.set_zlabel('loss')
plt.show()

[Figure: 3D surface of the MSE over weight and bias]
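To read the minimum off the surface numerically rather than visually, one can take the argmin over the grid. A sketch, assuming the same data, grid, and loss as above (recomputed here so the block is self-contained):

```python
import numpy as np

# Same data as above
np.random.seed(42)
n_examples = 500
x_data = 2 * np.random.randn(n_examples) + 1.5
y_data = 3 * x_data + 2.5 * np.random.randn(n_examples) + 1.5

# Same grid as above
w_list = np.arange(-20, 20, 0.1)
b_list = np.arange(-20, 20, 0.1)
xx, yy = np.meshgrid(w_list, b_list, indexing='xy')

# MSE at every grid point, accumulated one sample at a time
zz = np.zeros_like(xx)
for x_val, y_val in zip(x_data, y_data):
    zz += (x_val * xx + yy - y_val) ** 2
zz /= len(x_data)

# Index of the smallest loss; with indexing='xy', rows follow b and columns follow w
i, j = np.unravel_index(np.argmin(zz), zz.shape)
w_best, b_best = xx[i, j], yy[i, j]
print(f"grid minimum near w = {w_best:.1f}, b = {b_best:.1f}, MSE = {zz[i, j]:.3f}")
```

The result is limited by the 0.1 grid spacing, so it is only an approximation of the true minimizer.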

2. Backpropagation with PyTorch

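The problem:
[Figure: problem statement]
The task is to predict the data with a quadratic model, ŷ = w₀x² + w₁x + w₂ (w[0], w[1], w[2] in the code), and reduce the value of the loss function (MSE).
The code is as follows: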

import os
# Work around the "duplicate OpenMP runtime" error seen on some conda/Windows setups
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import matplotlib.pyplot as plt

# Dataset
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
# All weights start at 1; requires_grad=True makes autograd track w
w = torch.tensor([1.0, 1.0, 1.0], requires_grad=True)

# Forward pass: y_hat = w0*x^2 + w1*x + w2
def forward(x):
    return w[0]*(x**2) + w[1]*x + w[2]

# Squared-error loss for one sample
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# Training
print('predict (before training) ', 4, forward(4).item())
epoch_list = []
w_list = []
loss_list = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()                 # backward pass: fills w.grad
        print('\tgrad: ', x, y, w.grad)
        with torch.no_grad():
            w -= 0.01 * w.grad       # gradient-descent step, learning rate 0.01
        w.grad.zero_()               # reset the gradient, since backward() accumulates

    print('progress: ', epoch, l.item())   # loss of the last sample in this epoch
    epoch_list.append(epoch)
    w_list.append(w.detach().clone())      # snapshot of the weights after this epoch
    loss_list.append(l.item())
print('predict (after training) ', 4, forward(4).item())

# Plot loss vs. epoch
plt.plot(epoch_list, loss_list, 'b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.show()
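For this model, the gradient that `l.backward()` stores in `w.grad` for a single sample follows from the chain rule. Worked through for the first two updates of the run (w = (1, 1, 1), learning rate 0.01), it gives the first two `grad` lines of the output:

```latex
\begin{aligned}
\hat y &= w_0 x^2 + w_1 x + w_2, \qquad \ell = (\hat y - y)^2 \\
\frac{\partial \ell}{\partial w} &= 2(\hat y - y)\,\bigl(x^2,\; x,\; 1\bigr) \\
x = 1,\ y = 2:\quad \hat y &= 3, \quad \nabla_w \ell = 2(3 - 2)(1, 1, 1) = (2, 2, 2) \\
w &\leftarrow w - 0.01\,\nabla_w \ell = (0.98,\ 0.98,\ 0.98) \\
x = 2,\ y = 4:\quad \hat y &= 0.98\,(4 + 2 + 1) = 6.86, \quad \nabla_w \ell = 2(2.86)(4, 2, 1) = (22.88,\ 11.44,\ 5.72)
\end{aligned}
```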

The output is as follows:

predict (before training)  4 21.0
	grad:  1.0 2.0 tensor([2., 2., 2.])
	grad:  2.0 4.0 tensor([22.8800, 11.4400,  5.7200])
	grad:  3.0 6.0 tensor([77.0472, 25.6824,  8.5608])
progress:  0 18.321826934814453
	grad:  1.0 2.0 tensor([-1.1466, -1.1466, -1.1466])
	grad:  2.0 4.0 tensor([-15.5367,  -7.7683,  -3.8842])
	grad:  3.0 6.0 tensor([-30.4322, -10.1441,  -3.3814])
progress:  1 2.858394145965576
	grad:  1.0 2.0 tensor([0.3451, 0.3451, 0.3451])
	grad:  2.0 4.0 tensor([2.4273, 1.2137, 0.6068])
	grad:  3.0 6.0 tensor([19.4499,  6.4833,  2.1611])
progress:  2 1.1675907373428345
	grad:  1.0 2.0 tensor([-0.3224, -0.3224, -0.3224])
	grad:  2.0 4.0 tensor([-5.8458, -2.9229, -1.4614])
	grad:  3.0 6.0 tensor([-3.8829, -1.2943, -0.4314])
progress:  3 0.04653334245085716
	grad:  1.0 2.0 tensor([0.0137, 0.0137, 0.0137])
	grad:  2.0 4.0 tensor([-1.9141, -0.9570, -0.4785])
	grad:  3.0 6.0 tensor([6.8557, 2.2852, 0.7617])
progress:  4 0.14506366848945618
	grad:  1.0 2.0 tensor([-0.1182, -0.1182, -0.1182])
	grad:  2.0 4.0 tensor([-3.6644, -1.8322, -0.9161])
	grad:  3.0 6.0 tensor([1.7455, 0.5818, 0.1939])
progress:  5 0.009403289295732975
	grad:  1.0 2.0 tensor([-0.0333, -0.0333, -0.0333])
	grad:  2.0 4.0 tensor([-2.7739, -1.3869, -0.6935])
	grad:  3.0 6.0 tensor([4.0140, 1.3380, 0.4460])
progress:  6 0.04972923547029495
	grad:  1.0 2.0 tensor([-0.0501, -0.0501, -0.0501])
	grad:  2.0 4.0 tensor([-3.1150, -1.5575, -0.7788])
	grad:  3.0 6.0 tensor([2.8534, 0.9511, 0.3170])
progress:  7 0.025129113346338272
	grad:  1.0 2.0 tensor([-0.0205, -0.0205, -0.0205])
	grad:  2.0 4.0 tensor([-2.8858, -1.4429, -0.7215])
	grad:  3.0 6.0 tensor([3.2924, 1.0975, 0.3658])
progress:  8 0.03345605731010437
	grad:  1.0 2.0 tensor([-0.0134, -0.0134, -0.0134])
	grad:  2.0 4.0 tensor([-2.9247, -1.4623, -0.7312])
	grad:  3.0 6.0 tensor([2.9909, 0.9970, 0.3323])
progress:  9 0.027609655633568764
	grad:  1.0 2.0 tensor([0.0033, 0.0033, 0.0033])
	grad:  2.0 4.0 tensor([-2.8414, -1.4207, -0.7103])
	grad:  3.0 6.0 tensor([3.0377, 1.0126, 0.3375])
progress:  10 0.02848036028444767
	grad:  1.0 2.0 tensor([0.0148, 0.0148, 0.0148])
	grad:  2.0 4.0 tensor([-2.8174, -1.4087, -0.7043])
	grad:  3.0 6.0 tensor([2.9260, 0.9753, 0.3251])
progress:  11 0.02642466314136982
	grad:  1.0 2.0 tensor([0.0280, 0.0280, 0.0280])
	grad:  2.0 4.0 tensor([-2.7682, -1.3841, -0.6920])
	grad:  3.0 6.0 tensor([2.8915, 0.9638, 0.3213])
progress:  12 0.025804826989769936
	grad:  1.0 2.0 tensor([0.0397, 0.0397, 0.0397])
	grad:  2.0 4.0 tensor([-2.7330, -1.3665, -0.6832])
	grad:  3.0 6.0 tensor([2.8243, 0.9414, 0.3138])
progress:  13 0.02462013065814972
	grad:  1.0 2.0 tensor([0.0514, 0.0514, 0.0514])
	grad:  2.0 4.0 tensor([-2.6934, -1.3467, -0.6734])
	grad:  3.0 6.0 tensor([2.7756, 0.9252, 0.3084])
progress:  14 0.023777369409799576
	grad:  1.0 2.0 tensor([0.0624, 0.0624, 0.0624])
	grad:  2.0 4.0 tensor([-2.6580, -1.3290, -0.6645])
	grad:  3.0 6.0 tensor([2.7213, 0.9071, 0.3024])
progress:  15 0.0228563379496336
	grad:  1.0 2.0 tensor([0.0731, 0.0731, 0.0731])
	grad:  2.0 4.0 tensor([-2.6227, -1.3113, -0.6557])
	grad:  3.0 6.0 tensor([2.6725, 0.8908, 0.2969])
progress:  16 0.022044027224183083
	grad:  1.0 2.0 tensor([0.0833, 0.0833, 0.0833])
	grad:  2.0 4.0 tensor([-2.5893, -1.2946, -0.6473])
	grad:  3.0 6.0 tensor([2.6240, 0.8747, 0.2916])
progress:  17 0.02125072106719017
	grad:  1.0 2.0 tensor([0.0931, 0.0931, 0.0931])
	grad:  2.0 4.0 tensor([-2.5568, -1.2784, -0.6392])
	grad:  3.0 6.0 tensor([2.5780, 0.8593, 0.2864])
progress:  18 0.020513182505965233
	grad:  1.0 2.0 tensor([0.1025, 0.1025, 0.1025])
	grad:  2.0 4.0 tensor([-2.5258, -1.2629, -0.6314])
	grad:  3.0 6.0 tensor([2.5335, 0.8445, 0.2815])
progress:  19 0.019810274243354797
	grad:  1.0 2.0 tensor([0.1116, 0.1116, 0.1116])
	grad:  2.0 4.0 tensor([-2.4958, -1.2479, -0.6239])
	grad:  3.0 6.0 tensor([2.4908, 0.8303, 0.2768])
progress:  20 0.019148115068674088
	grad:  1.0 2.0 tensor([0.1203, 0.1203, 0.1203])
	grad:  2.0 4.0 tensor([-2.4669, -1.2335, -0.6167])
	grad:  3.0 6.0 tensor([2.4496, 0.8165, 0.2722])
progress:  21 0.018520694226026535
	grad:  1.0 2.0 tensor([0.1286, 0.1286, 0.1286])
	grad:  2.0 4.0 tensor([-2.4392, -1.2196, -0.6098])
	grad:  3.0 6.0 tensor([2.4101, 0.8034, 0.2678])
progress:  22 0.017927465960383415
	grad:  1.0 2.0 tensor([0.1367, 0.1367, 0.1367])
	grad:  2.0 4.0 tensor([-2.4124, -1.2062, -0.6031])
	grad:  3.0 6.0 tensor([2.3720, 0.7907, 0.2636])
progress:  23 0.01736525259912014
	grad:  1.0 2.0 tensor([0.1444, 0.1444, 0.1444])
	grad:  2.0 4.0 tensor([-2.3867, -1.1933, -0.5967])
	grad:  3.0 6.0 tensor([2.3354, 0.7785, 0.2595])
progress:  24 0.016833148896694183
	grad:  1.0 2.0 tensor([0.1518, 0.1518, 0.1518])
	grad:  2.0 4.0 tensor([-2.3619, -1.1810, -0.5905])
	grad:  3.0 6.0 tensor([2.3001, 0.7667, 0.2556])
progress:  25 0.01632905937731266
	grad:  1.0 2.0 tensor([0.1589, 0.1589, 0.1589])
	grad:  2.0 4.0 tensor([-2.3380, -1.1690, -0.5845])
	grad:  3.0 6.0 tensor([2.2662, 0.7554, 0.2518])
progress:  26 0.01585075818002224
	grad:  1.0 2.0 tensor([0.1657, 0.1657, 0.1657])
	grad:  2.0 4.0 tensor([-2.3151, -1.1575, -0.5788])
	grad:  3.0 6.0 tensor([2.2336, 0.7445, 0.2482])
progress:  27 0.015397666022181511
	grad:  1.0 2.0 tensor([0.1723, 0.1723, 0.1723])
	grad:  2.0 4.0 tensor([-2.2929, -1.1465, -0.5732])
	grad:  3.0 6.0 tensor([2.2022, 0.7341, 0.2447])
progress:  28 0.014967591501772404
	grad:  1.0 2.0 tensor([0.1786, 0.1786, 0.1786])
	grad:  2.0 4.0 tensor([-2.2716, -1.1358, -0.5679])
	grad:  3.0 6.0 tensor([2.1719, 0.7240, 0.2413])
progress:  29 0.014559715054929256
	grad:  1.0 2.0 tensor([0.1846, 0.1846, 0.1846])
	grad:  2.0 4.0 tensor([-2.2511, -1.1255, -0.5628])
	grad:  3.0 6.0 tensor([2.1429, 0.7143, 0.2381])
progress:  30 0.014172340743243694
	grad:  1.0 2.0 tensor([0.1904, 0.1904, 0.1904])
	grad:  2.0 4.0 tensor([-2.2313, -1.1157, -0.5578])
	grad:  3.0 6.0 tensor([2.1149, 0.7050, 0.2350])
progress:  31 0.013804304413497448
	grad:  1.0 2.0 tensor([0.1960, 0.1960, 0.1960])
	grad:  2.0 4.0 tensor([-2.2123, -1.1061, -0.5531])
	grad:  3.0 6.0 tensor([2.0879, 0.6960, 0.2320])
progress:  32 0.013455045409500599
	grad:  1.0 2.0 tensor([0.2014, 0.2014, 0.2014])
	grad:  2.0 4.0 tensor([-2.1939, -1.0970, -0.5485])
	grad:  3.0 6.0 tensor([2.0620, 0.6873, 0.2291])
progress:  33 0.013122711330652237
	grad:  1.0 2.0 tensor([0.2065, 0.2065, 0.2065])
	grad:  2.0 4.0 tensor([-2.1763, -1.0881, -0.5441])
	grad:  3.0 6.0 tensor([2.0370, 0.6790, 0.2263])
progress:  34 0.01280694268643856
	grad:  1.0 2.0 tensor([0.2114, 0.2114, 0.2114])
	grad:  2.0 4.0 tensor([-2.1592, -1.0796, -0.5398])
	grad:  3.0 6.0 tensor([2.0130, 0.6710, 0.2237])
progress:  35 0.012506747618317604
	grad:  1.0 2.0 tensor([0.2162, 0.2162, 0.2162])
	grad:  2.0 4.0 tensor([-2.1428, -1.0714, -0.5357])
	grad:  3.0 6.0 tensor([1.9899, 0.6633, 0.2211])
progress:  36 0.012220758944749832
	grad:  1.0 2.0 tensor([0.2207, 0.2207, 0.2207])
	grad:  2.0 4.0 tensor([-2.1270, -1.0635, -0.5317])
	grad:  3.0 6.0 tensor([1.9676, 0.6559, 0.2186])
progress:  37 0.01194891706109047
	grad:  1.0 2.0 tensor([0.2251, 0.2251, 0.2251])
	grad:  2.0 4.0 tensor([-2.1118, -1.0559, -0.5279])
	grad:  3.0 6.0 tensor([1.9462, 0.6487, 0.2162])
progress:  38 0.011689926497638226
	grad:  1.0 2.0 tensor([0.2292, 0.2292, 0.2292])
	grad:  2.0 4.0 tensor([-2.0971, -1.0485, -0.5243])
	grad:  3.0 6.0 tensor([1.9255, 0.6418, 0.2139])
progress:  39 0.01144315768033266
	grad:  1.0 2.0 tensor([0.2333, 0.2333, 0.2333])
	grad:  2.0 4.0 tensor([-2.0829, -1.0414, -0.5207])
	grad:  3.0 6.0 tensor([1.9057, 0.6352, 0.2117])
progress:  40 0.011208509095013142
	grad:  1.0 2.0 tensor([0.2371, 0.2371, 0.2371])
	grad:  2.0 4.0 tensor([-2.0693, -1.0346, -0.5173])
	grad:  3.0 6.0 tensor([1.8865, 0.6288, 0.2096])
progress:  41 0.0109840864315629
	grad:  1.0 2.0 tensor([0.2408, 0.2408, 0.2408])
	grad:  2.0 4.0 tensor([-2.0561, -1.0280, -0.5140])
	grad:  3.0 6.0 tensor([1.8681, 0.6227, 0.2076])
progress:  42 0.010770938359200954
	grad:  1.0 2.0 tensor([0.2444, 0.2444, 0.2444])
	grad:  2.0 4.0 tensor([-2.0434, -1.0217, -0.5108])
	grad:  3.0 6.0 tensor([1.8503, 0.6168, 0.2056])
progress:  43 0.010566935874521732
	grad:  1.0 2.0 tensor([0.2478, 0.2478, 0.2478])
	grad:  2.0 4.0 tensor([-2.0312, -1.0156, -0.5078])
	grad:  3.0 6.0 tensor([1.8332, 0.6111, 0.2037])
progress:  44 0.010372749529778957
	grad:  1.0 2.0 tensor([0.2510, 0.2510, 0.2510])
	grad:  2.0 4.0 tensor([-2.0194, -1.0097, -0.5048])
	grad:  3.0 6.0 tensor([1.8168, 0.6056, 0.2019])
progress:  45 0.010187389329075813
	grad:  1.0 2.0 tensor([0.2542, 0.2542, 0.2542])
	grad:  2.0 4.0 tensor([-2.0080, -1.0040, -0.5020])
	grad:  3.0 6.0 tensor([1.8009, 0.6003, 0.2001])
	grad:  3.0 6.0 tensor([1.8009, 0.6003, 0.2001])
progress:  46 0.010010283440351486
	grad:  1.0 2.0 tensor([0.2572, 0.2572, 0.2572])
	grad:  2.0 4.0 tensor([-1.9970, -0.9985, -0.4992])
	grad:  3.0 6.0 tensor([1.7856, 0.5952, 0.1984])
progress:  47 0.00984097272157669
	grad:  1.0 2.0 tensor([0.2600, 0.2600, 0.2600])
	grad:  2.0 4.0 tensor([-1.9864, -0.9932, -0.4966])
	grad:  3.0 6.0 tensor([1.7709, 0.5903, 0.1968])
progress:  48 0.009679674170911312
	grad:  1.0 2.0 tensor([0.2628, 0.2628, 0.2628])
	grad:  2.0 4.0 tensor([-1.9762, -0.9881, -0.4940])
	grad:  3.0 6.0 tensor([1.7568, 0.5856, 0.1952])
progress:  49 0.009525291621685028
	grad:  1.0 2.0 tensor([0.2655, 0.2655, 0.2655])
	grad:  2.0 4.0 tensor([-1.9663, -0.9832, -0.4916])
	grad:  3.0 6.0 tensor([1.7431, 0.5810, 0.1937])
progress:  50 0.00937769003212452
	grad:  1.0 2.0 tensor([0.2680, 0.2680, 0.2680])
	grad:  2.0 4.0 tensor([-1.9568, -0.9784, -0.4892])
	grad:  3.0 6.0 tensor([1.7299, 0.5766, 0.1922])
progress:  51 0.009236648678779602
	grad:  1.0 2.0 tensor([0.2704, 0.2704, 0.2704])
	grad:  2.0 4.0 tensor([-1.9476, -0.9738, -0.4869])
	grad:  3.0 6.0 tensor([1.7172, 0.5724, 0.1908])
progress:  52 0.00910158734768629
	grad:  1.0 2.0 tensor([0.2728, 0.2728, 0.2728])
	grad:  2.0 4.0 tensor([-1.9387, -0.9694, -0.4847])
	grad:  3.0 6.0 tensor([1.7050, 0.5683, 0.1894])
progress:  53 0.00897257961332798
	grad:  1.0 2.0 tensor([0.2750, 0.2750, 0.2750])
	grad:  2.0 4.0 tensor([-1.9301, -0.9651, -0.4825])
	grad:  3.0 6.0 tensor([1.6932, 0.5644, 0.1881])
progress:  54 0.008848887868225574
	grad:  1.0 2.0 tensor([0.2771, 0.2771, 0.2771])
	grad:  2.0 4.0 tensor([-1.9219, -0.9609, -0.4805])
	grad:  3.0 6.0 tensor([1.6819, 0.5606, 0.1869])
progress:  55 0.008730598725378513
	grad:  1.0 2.0 tensor([0.2792, 0.2792, 0.2792])
	grad:  2.0 4.0 tensor([-1.9139, -0.9569, -0.4785])
	grad:  3.0 6.0 tensor([1.6709, 0.5570, 0.1857])
progress:  56 0.00861735362559557
	grad:  1.0 2.0 tensor([0.2811, 0.2811, 0.2811])
	grad:  2.0 4.0 tensor([-1.9062, -0.9531, -0.4765])
	grad:  3.0 6.0 tensor([1.6604, 0.5535, 0.1845])
progress:  57 0.008508718572556973
	grad:  1.0 2.0 tensor([0.2830, 0.2830, 0.2830])
	grad:  2.0 4.0 tensor([-1.8987, -0.9493, -0.4747])
	grad:  3.0 6.0 tensor([1.6502, 0.5501, 0.1834])
progress:  58 0.008404706604778767
	grad:  1.0 2.0 tensor([0.2848, 0.2848, 0.2848])
	grad:  2.0 4.0 tensor([-1.8915, -0.9457, -0.4729])
	grad:  3.0 6.0 tensor([1.6404, 0.5468, 0.1823])
progress:  59 0.008305158466100693
	grad:  1.0 2.0 tensor([0.2865, 0.2865, 0.2865])
	grad:  2.0 4.0 tensor([-1.8845, -0.9423, -0.4711])
	grad:  3.0 6.0 tensor([1.6309, 0.5436, 0.1812])
progress:  60 0.00820931326597929
	grad:  1.0 2.0 tensor([0.2882, 0.2882, 0.2882])
	grad:  2.0 4.0 tensor([-1.8778, -0.9389, -0.4694])
	grad:  3.0 6.0 tensor([1.6218, 0.5406, 0.1802])
progress:  61 0.008117804303765297
	grad:  1.0 2.0 tensor([0.2898, 0.2898, 0.2898])
	grad:  2.0 4.0 tensor([-1.8713, -0.9356, -0.4678])
	grad:  3.0 6.0 tensor([1.6130, 0.5377, 0.1792])
progress:  62 0.008029798977077007
	grad:  1.0 2.0 tensor([0.2913, 0.2913, 0.2913])
	grad:  2.0 4.0 tensor([-1.8650, -0.9325, -0.4662])
	grad:  3.0 6.0 tensor([1.6045, 0.5348, 0.1783])
progress:  63 0.007945418357849121
	grad:  1.0 2.0 tensor([0.2927, 0.2927, 0.2927])
	grad:  2.0 4.0 tensor([-1.8589, -0.9294, -0.4647])
	grad:  3.0 6.0 tensor([1.5962, 0.5321, 0.1774])
progress:  64 0.007864190265536308
	grad:  1.0 2.0 tensor([0.2941, 0.2941, 0.2941])
	grad:  2.0 4.0 tensor([-1.8530, -0.9265, -0.4632])
	grad:  3.0 6.0 tensor([1.5884, 0.5295, 0.1765])
progress:  65 0.007786744274199009
	grad:  1.0 2.0 tensor([0.2954, 0.2954, 0.2954])
	grad:  2.0 4.0 tensor([-1.8473, -0.9236, -0.4618])
	grad:  3.0 6.0 tensor([1.5807, 0.5269, 0.1756])
progress:  66 0.007711691781878471
	grad:  1.0 2.0 tensor([0.2967, 0.2967, 0.2967])
	grad:  2.0 4.0 tensor([-1.8417, -0.9209, -0.4604])
	grad:  3.0 6.0 tensor([1.5733, 0.5244, 0.1748])
progress:  67 0.007640169933438301
	grad:  1.0 2.0 tensor([0.2979, 0.2979, 0.2979])
	grad:  2.0 4.0 tensor([-1.8364, -0.9182, -0.4591])
	grad:  3.0 6.0 tensor([1.5662, 0.5221, 0.1740])
progress:  68 0.007570972666144371
	grad:  1.0 2.0 tensor([0.2991, 0.2991, 0.2991])
	grad:  2.0 4.0 tensor([-1.8312, -0.9156, -0.4578])
	grad:  3.0 6.0 tensor([1.5593, 0.5198, 0.1733])
progress:  69 0.007504733745008707
	grad:  1.0 2.0 tensor([0.3002, 0.3002, 0.3002])
	grad:  2.0 4.0 tensor([-1.8262, -0.9131, -0.4566])
	grad:  3.0 6.0 tensor([1.5527, 0.5176, 0.1725])
progress:  70 0.007440924644470215
	grad:  1.0 2.0 tensor([0.3012, 0.3012, 0.3012])
	grad:  2.0 4.0 tensor([-1.8214, -0.9107, -0.4553])
	grad:  3.0 6.0 tensor([1.5463, 0.5154, 0.1718])
progress:  71 0.007379599846899509
	grad:  1.0 2.0 tensor([0.3022, 0.3022, 0.3022])
	grad:  2.0 4.0 tensor([-1.8167, -0.9083, -0.4542])
	grad:  3.0 6.0 tensor([1.5401, 0.5134, 0.1711])
progress:  72 0.007320486940443516
	grad:  1.0 2.0 tensor([0.3032, 0.3032, 0.3032])
	grad:  2.0 4.0 tensor([-1.8121, -0.9060, -0.4530])
	grad:  3.0 6.0 tensor([1.5341, 0.5114, 0.1705])
progress:  73 0.007263725157827139
	grad:  1.0 2.0 tensor([0.3041, 0.3041, 0.3041])
	grad:  2.0 4.0 tensor([-1.8077, -0.9038, -0.4519])
	grad:  3.0 6.0 tensor([1.5283, 0.5094, 0.1698])
progress:  74 0.007209045812487602
	grad:  1.0 2.0 tensor([0.3050, 0.3050, 0.3050])
	grad:  2.0 4.0 tensor([-1.8034, -0.9017, -0.4508])
	grad:  3.0 6.0 tensor([1.5227, 0.5076, 0.1692])
progress:  75 0.007156429346650839
	grad:  1.0 2.0 tensor([0.3058, 0.3058, 0.3058])
	grad:  2.0 4.0 tensor([-1.7992, -0.8996, -0.4498])
	grad:  3.0 6.0 tensor([1.5173, 0.5058, 0.1686])
progress:  76 0.007105532102286816
	grad:  1.0 2.0 tensor([0.3066, 0.3066, 0.3066])
	grad:  2.0 4.0 tensor([-1.7952, -0.8976, -0.4488])
	grad:  3.0 6.0 tensor([1.5121, 0.5040, 0.1680])
progress:  77 0.00705681974068284
	grad:  1.0 2.0 tensor([0.3073, 0.3073, 0.3073])
	grad:  2.0 4.0 tensor([-1.7913, -0.8956, -0.4478])
	grad:  3.0 6.0 tensor([1.5070, 0.5023, 0.1674])
progress:  78 0.007009552326053381
	grad:  1.0 2.0 tensor([0.3081, 0.3081, 0.3081])
	grad:  2.0 4.0 tensor([-1.7875, -0.8937, -0.4469])
	grad:  3.0 6.0 tensor([1.5021, 0.5007, 0.1669])
progress:  79 0.006964194122701883
	grad:  1.0 2.0 tensor([0.3087, 0.3087, 0.3087])
	grad:  2.0 4.0 tensor([-1.7838, -0.8919, -0.4459])
	grad:  3.0 6.0 tensor([1.4974, 0.4991, 0.1664])
progress:  80 0.006920332089066505
	grad:  1.0 2.0 tensor([0.3094, 0.3094, 0.3094])
	grad:  2.0 4.0 tensor([-1.7802, -0.8901, -0.4450])
	grad:  3.0 6.0 tensor([1.4928, 0.4976, 0.1659])
progress:  81 0.006878111511468887
	grad:  1.0 2.0 tensor([0.3100, 0.3100, 0.3100])
	grad:  2.0 4.0 tensor([-1.7767, -0.8883, -0.4442])
	grad:  3.0 6.0 tensor([1.4884, 0.4961, 0.1654])
progress:  82 0.006837360095232725
	grad:  1.0 2.0 tensor([0.3106, 0.3106, 0.3106])
	grad:  2.0 4.0 tensor([-1.7733, -0.8867, -0.4433])
	grad:  3.0 6.0 tensor([1.4841, 0.4947, 0.1649])
progress:  83 0.006797831039875746
	grad:  1.0 2.0 tensor([0.3111, 0.3111, 0.3111])
	grad:  2.0 4.0 tensor([-1.7700, -0.8850, -0.4425])
	grad:  3.0 6.0 tensor([1.4800, 0.4933, 0.1644])
progress:  84 0.006760062649846077
	grad:  1.0 2.0 tensor([0.3117, 0.3117, 0.3117])
	grad:  2.0 4.0 tensor([-1.7668, -0.8834, -0.4417])
	grad:  3.0 6.0 tensor([1.4759, 0.4920, 0.1640])
progress:  85 0.006723103579133749
	grad:  1.0 2.0 tensor([0.3122, 0.3122, 0.3122])
	grad:  2.0 4.0 tensor([-1.7637, -0.8818, -0.4409])
	grad:  3.0 6.0 tensor([1.4720, 0.4907, 0.1636])
progress:  86 0.00668772729113698
	grad:  1.0 2.0 tensor([0.3127, 0.3127, 0.3127])
	grad:  2.0 4.0 tensor([-1.7607, -0.8803, -0.4402])
	grad:  3.0 6.0 tensor([1.4682, 0.4894, 0.1631])
progress:  87 0.006653300020843744
	grad:  1.0 2.0 tensor([0.3131, 0.3131, 0.3131])
	grad:  2.0 4.0 tensor([-1.7577, -0.8789, -0.4394])
	grad:  3.0 6.0 tensor([1.4646, 0.4882, 0.1627])
progress:  88 0.0066203586757183075
	grad:  1.0 2.0 tensor([0.3135, 0.3135, 0.3135])
	grad:  2.0 4.0 tensor([-1.7548, -0.8774, -0.4387])
	grad:  3.0 6.0 tensor([1.4610, 0.4870, 0.1623])
progress:  89 0.0065881176851689816
	grad:  1.0 2.0 tensor([0.3139, 0.3139, 0.3139])
	grad:  2.0 4.0 tensor([-1.7520, -0.8760, -0.4380])
	grad:  3.0 6.0 tensor([1.4576, 0.4859, 0.1620])
progress:  90 0.0065572685562074184
	grad:  1.0 2.0 tensor([0.3143, 0.3143, 0.3143])
	grad:  2.0 4.0 tensor([-1.7493, -0.8747, -0.4373])
	grad:  3.0 6.0 tensor([1.4542, 0.4847, 0.1616])
progress:  91 0.0065271081402897835
	grad:  1.0 2.0 tensor([0.3147, 0.3147, 0.3147])
	grad:  2.0 4.0 tensor([-1.7466, -0.8733, -0.4367])
	grad:  3.0 6.0 tensor([1.4510, 0.4837, 0.1612])
progress:  92 0.00649801641702652
	grad:  1.0 2.0 tensor([0.3150, 0.3150, 0.3150])
	grad:  2.0 4.0 tensor([-1.7441, -0.8720, -0.4360])
	grad:  3.0 6.0 tensor([1.4478, 0.4826, 0.1609])
progress:  93 0.0064699104987084866
	grad:  1.0 2.0 tensor([0.3153, 0.3153, 0.3153])
	grad:  2.0 4.0 tensor([-1.7415, -0.8708, -0.4354])
	grad:  3.0 6.0 tensor([1.4448, 0.4816, 0.1605])
progress:  94 0.006442630663514137
	grad:  1.0 2.0 tensor([0.3156, 0.3156, 0.3156])
	grad:  2.0 4.0 tensor([-1.7391, -0.8695, -0.4348])
	grad:  3.0 6.0 tensor([1.4418, 0.4806, 0.1602])
progress:  95 0.006416172254830599
	grad:  1.0 2.0 tensor([0.3159, 0.3159, 0.3159])
	grad:  2.0 4.0 tensor([-1.7366, -0.8683, -0.4342])
	grad:  3.0 6.0 tensor([1.4389, 0.4796, 0.1599])
progress:  96 0.006390606984496117
	grad:  1.0 2.0 tensor([0.3161, 0.3161, 0.3161])
	grad:  2.0 4.0 tensor([-1.7343, -0.8671, -0.4336])
	grad:  3.0 6.0 tensor([1.4361, 0.4787, 0.1596])
progress:  97 0.0063657015562057495
	grad:  1.0 2.0 tensor([0.3164, 0.3164, 0.3164])
	grad:  2.0 4.0 tensor([-1.7320, -0.8660, -0.4330])
	grad:  3.0 6.0 tensor([1.4334, 0.4778, 0.1593])
progress:  98 0.0063416799530386925
	grad:  1.0 2.0 tensor([0.3166, 0.3166, 0.3166])
	grad:  2.0 4.0 tensor([-1.7297, -0.8649, -0.4324])
	grad:  3.0 6.0 tensor([1.4308, 0.4769, 0.1590])
progress:  99 0.00631808303296566
predict (after training)  4 8.544171333312988
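The logged numbers can be reproduced without autograd. A minimal pure-Python sketch (the helpers `predict` and `grad_w` are hypothetical re-implementations, not PyTorch APIs) that applies the hand-derived gradient 2(ŷ − y)(x², x, 1) with the same data, initial weights, learning rate, and 100 epochs. It runs in double precision, so the last digits may differ slightly from PyTorch's float32 output:

```python
# Pure-Python re-implementation of the PyTorch training loop (no autograd)
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def predict(w, x):
    """Quadratic model y_hat = w0*x^2 + w1*x + w2."""
    return w[0] * x**2 + w[1] * x + w[2]

def grad_w(w, x, y):
    """Hand-derived gradient of (y_hat - y)^2 with respect to (w0, w1, w2)."""
    r = 2.0 * (predict(w, x) - y)
    return [r * x**2, r * x, r]

w = [1.0, 1.0, 1.0]
lr = 0.01
grads_epoch0 = []   # the three gradients of the first epoch
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        g = grad_w(w, x, y)
        if epoch == 0:
            grads_epoch0.append(g)
        w = [wi - lr * gi for wi, gi in zip(w, g)]

print(grads_epoch0)
print('predict (after training)', 4, predict(w, 4.0))
```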

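The loss shows an overall downward trend as the number of epochs grows, as the figure below shows:
[Figure: loss vs. epoch]
As can be seen, the prediction at x = 4 is about 8.54, which differs from the true value of 8 (the training data satisfy y = 2x). The gap can be narrowed by training for more epochs or by adjusting the learning rate, the initial parameters, and so on.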
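Since every training point satisfies y = 2x, the parameters w = (0, 2, 0) fit the data exactly, so the ideal prediction at x = 4 is 8. A sketch of the first remedy, training for more epochs, using a hypothetical pure-Python `train()` that mirrors the PyTorch loop (same data, start point w = (1, 1, 1), and learning rate 0.01):

```python
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

def train(epochs, lr=0.01):
    """Hypothetical helper: per-sample gradient descent on
    y_hat = w0*x^2 + w1*x + w2, starting from w = (1, 1, 1)."""
    w = [1.0, 1.0, 1.0]
    for _ in range(epochs):
        for x, y in zip(x_data, y_data):
            r = 2.0 * (w[0] * x**2 + w[1] * x + w[2] - y)   # 2 * (y_hat - y)
            w = [w[0] - lr * r * x**2, w[1] - lr * r * x, w[2] - lr * r]
    return w

preds = {}
for epochs in (100, 1000, 5000):
    w = train(epochs)
    preds[epochs] = w[0] * 4**2 + w[1] * 4 + w[2]   # prediction at x = 4
    print(epochs, [round(v, 4) for v in w], round(preds[epochs], 4))
```

On the learning-rate side there is little headroom: each update on the x = 3 sample multiplies that sample's residual by (1 − 2·lr·(9² + 3² + 1²)) = (1 − 182·lr), which stays below 1 in magnitude only for lr < 2/182 ≈ 0.011.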


Reposted from blog.csdn.net/weixin_43821559/article/details/123296140