PyTorch Learning (2): Gradient Descent

Tutorial video: https://www.bilibili.com/video/av93206234

Task: simulate the gradient descent algorithm and, on the dataset given by x_data and y_data, find a suitable value of w for the model y = w * x.
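For reference, the math the code below implements is plain MSE gradient descent. With N training pairs (x_n, y_n), the cost and its gradient with respect to w are

cost(w) = (1/N) * Σ (x_n * w - y_n)^2
d(cost)/dw = (1/N) * Σ 2 * x_n * (x_n * w - y_n)

and each step updates w := w - 0.01 * d(cost)/dw, where 0.01 is the learning rate. The cost() and gradient() functions below compute exactly these two quantities.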
Code:

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]

w = 1.0  # initial weight, set to 1

def forward(x):
    return x * w

def cost(xs, ys):  # MSE cost averaged over the whole dataset
    cost = 0
    for x,y in zip(xs,ys):
        y_pred = forward(x)
        cost += (y_pred-y)**2
    return cost/len(xs)

def gradient(xs, ys):  # gradient of the cost w.r.t. w, averaged over the dataset
    grad = 0
    for x, y in zip(xs, ys):
        grad += 2 * x * (x * w - y)
    return grad / len(xs)

print('Predict (before training)', 4, forward(4))

mse_list=[]
for epoch in range(100):  # train for 100 epochs
    cost_val = cost(x_data, y_data)  # computed only so the cost can be plotted
    grad_val = gradient(x_data,y_data)
    w -= 0.01 * grad_val
    mse_list.append(cost_val)
    print('Epoch:',epoch,'w=',w,'cost=',cost_val)
print('Predict (after training)', 4, forward(4))

epoch_list = np.arange(0, 100, 1)
plt.plot(epoch_list, mse_list)
plt.ylabel('cost')
plt.xlabel('epoch')
plt.show()

Results:

Predict (before training) 4 4.0
Epoch: 0 w= 1.0933333333333333 cost= 4.666666666666667
Epoch: 1 w= 1.1779555555555554 cost= 3.8362074074074086
Epoch: 2 w= 1.2546797037037036 cost= 3.1535329869958857
Epoch: 3 w= 1.3242429313580246 cost= 2.592344272332262
Epoch: 4 w= 1.3873135910979424 cost= 2.1310222071581117
Epoch: 5 w= 1.4444976559288012 cost= 1.7517949663820642
Epoch: 6 w= 1.4963445413754464 cost= 1.440053319920117
Epoch: 7 w= 1.5433523841804047 cost= 1.1837878313441108
Epoch: 8 w= 1.5859728283235668 cost= 0.9731262101573632
Epoch: 9 w= 1.6246153643467005 cost= 0.7999529948031382
Epoch: 10 w= 1.659651263674342 cost= 0.6575969151946154
……
Epoch: 90 w= 1.9998658050763347 cost= 1.0223124683409346e-07
Epoch: 91 w= 1.9998783299358769 cost= 8.403862850836479e-08
Epoch: 92 w= 1.9998896858085284 cost= 6.908348768398496e-08
Epoch: 93 w= 1.9998999817997325 cost= 5.678969725349543e-08
Epoch: 94 w= 1.9999093168317574 cost= 4.66836551287917e-08
Epoch: 95 w= 1.9999177805941268 cost= 3.8376039345125727e-08
Epoch: 96 w= 1.9999254544053418 cost= 3.154680994333735e-08
Epoch: 97 w= 1.9999324119941766 cost= 2.593287985380858e-08
Epoch: 98 w= 1.9999387202080534 cost= 2.131797981222471e-08
Epoch: 99 w= 1.9999444396553017 cost= 1.752432687141379e-08
Predict (after training) 4 7.999777758621207

Cost curve: (figure: cost vs. epoch)

During training, gradient descent can get stuck at a "saddle point", where the gradient averaged over all samples is 0, so w -= 0.01 * 0 leaves w unchanged and training cannot make further progress. To get around this we can use stochastic gradient descent (SGD): take the gradient of a single randomly chosen sample (x, y) as the basis for each update, instead of the gradient averaged over all samples.
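Written out, the per-sample update that the code below performs is

loss(w) = (x_n * w - y_n)^2
d(loss)/dw = 2 * x_n * (x_n * w - y_n)
w := w - 0.01 * d(loss)/dw

so w changes once per sample rather than once per pass over the whole dataset, and a zero averaged gradient no longer stalls training as long as individual samples still have non-zero gradients.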

The code is as follows:

import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]

w = 1.0

def forward(x):
    return x * w

def loss(x, y):  # loss of a single sample
    y_pred = forward(x)
    return (y_pred - y) ** 2

def gradient(x, y):  # gradient of a single sample's loss w.r.t. w
    return 2 * x * (x * w - y)

print('Predict (before training)', 4, forward(4))

loss_list = []
for epoch in range(100):
    for x, y in zip(x_data, y_data):  # samples are taken in order here, not drawn at random
        grad_val = gradient(x, y)
        w -= 0.01 * grad_val
        print('grad:', x, y, grad_val)
        loss_val = loss(x, y)
    loss_list.append(loss_val)  # loss of the last sample in this epoch, recorded for plotting
    print('Epoch:', epoch, 'w=', w, 'loss=', loss_val)
print('Predict (after training)', 4, forward(4))

epoch_list = np.arange(0, 100, 1)
plt.plot(epoch_list, loss_list)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

Results:
Predict (before training) 4 4.0
grad: 1.0 2.0 -2.0
grad: 2.0 4.0 -7.84
grad: 3.0 6.0 -16.2288
Epoch: 0 w= 1.260688 loss= 4.919240100095999
grad: 1.0 2.0 -1.478624
grad: 2.0 4.0 -5.796206079999999
grad: 3.0 6.0 -11.998146585599997
Epoch: 1 w= 1.453417766656 loss= 2.688769240265834
grad: 1.0 2.0 -1.093164466688
grad: 2.0 4.0 -4.285204709416961
grad: 3.0 6.0 -8.87037374849311
Epoch: 2 w= 1.5959051959019805 loss= 1.4696334962911515
grad: 1.0 2.0 -0.8081896081960389
grad: 2.0 4.0 -3.1681032641284723
grad: 3.0 6.0 -6.557973756745939
Epoch: 3 w= 1.701247862192685 loss= 0.8032755585999681
grad: 1.0 2.0 -0.59750427561463
grad: 2.0 4.0 -2.3422167604093502
grad: 3.0 6.0 -4.848388694047353
Epoch: 4 w= 1.7791289594933983 loss= 0.43905614881022015
grad: 1.0 2.0 -0.44174208101320334
grad: 2.0 4.0 -1.7316289575717576
grad: 3.0 6.0 -3.584471942173538
Epoch: 5 w= 1.836707389300983 loss= 0.2399802903801062
grad: 1.0 2.0 -0.3265852213980338
grad: 2.0 4.0 -1.2802140678802925
grad: 3.0 6.0 -2.650043120512205
Epoch: 6 w= 1.8792758133988885 loss= 0.1311689630744999
grad: 1.0 2.0 -0.241448373202223
grad: 2.0 4.0 -0.946477622952715
grad: 3.0 6.0 -1.9592086795121197
Epoch: 7 w= 1.910747160155559 loss= 0.07169462478267678
grad: 1.0 2.0 -0.17850567968888198
grad: 2.0 4.0 -0.6997422643804168
grad: 3.0 6.0 -1.4484664872674653
Epoch: 8 w= 1.9340143044689266 loss= 0.03918700813247573
grad: 1.0 2.0 -0.13197139106214673
grad: 2.0 4.0 -0.5173278529636143
grad: 3.0 6.0 -1.0708686556346834
Epoch: 9 w= 1.9512159834655312 loss= 0.021418922423117836
grad: 1.0 2.0 -0.09756803306893769
grad: 2.0 4.0 -0.38246668963023644
grad: 3.0 6.0 -0.7917060475345892
Epoch: 10 w= 1.9639333911678687 loss= 0.01170720245384975
grad: 1.0 2.0 -0.07213321766426262
grad: 2.0 4.0 -0.2827622132439096
grad: 3.0 6.0 -0.5853177814148953
……
Epoch: 90 w= 1.9999999999988431 loss= 1.2047849775995315e-23
grad: 1.0 2.0 -2.3137047833188262e-12
grad: 2.0 4.0 -9.070078021977679e-12
grad: 3.0 6.0 -1.8779644506139448e-11
Epoch: 91 w= 1.9999999999991447 loss= 6.5840863393251405e-24
grad: 1.0 2.0 -1.7106316363424412e-12
grad: 2.0 4.0 -6.7057470687359455e-12
grad: 3.0 6.0 -1.3882228699912957e-11
Epoch: 92 w= 1.9999999999993676 loss= 3.5991747246272455e-24
grad: 1.0 2.0 -1.2647660696529783e-12
grad: 2.0 4.0 -4.957811938766099e-12
grad: 3.0 6.0 -1.0263789818054647e-11
Epoch: 93 w= 1.9999999999995324 loss= 1.969312363793734e-24
grad: 1.0 2.0 -9.352518759442319e-13
grad: 2.0 4.0 -3.666400516522117e-12
grad: 3.0 6.0 -7.58859641791787e-12
Epoch: 94 w= 1.9999999999996543 loss= 1.0761829795642296e-24
grad: 1.0 2.0 -6.914468997365475e-13
grad: 2.0 4.0 -2.7107205369247822e-12
grad: 3.0 6.0 -5.611511255665391e-12
Epoch: 95 w= 1.9999999999997444 loss= 5.875191475205477e-25
grad: 1.0 2.0 -5.111466805374221e-13
grad: 2.0 4.0 -2.0037305148434825e-12
grad: 3.0 6.0 -4.1460168631601846e-12
Epoch: 96 w= 1.999999999999811 loss= 3.2110109830478153e-25
grad: 1.0 2.0 -3.779199175824033e-13
grad: 2.0 4.0 -1.4814816040598089e-12
grad: 3.0 6.0 -3.064215547965432e-12
Epoch: 97 w= 1.9999999999998603 loss= 1.757455879087579e-25
grad: 1.0 2.0 -2.793321129956894e-13
grad: 2.0 4.0 -1.0942358130705543e-12
grad: 3.0 6.0 -2.2648549702353193e-12
Epoch: 98 w= 1.9999999999998967 loss= 9.608404711682446e-26
grad: 1.0 2.0 -2.0650148258027912e-13
grad: 2.0 4.0 -8.100187187665142e-13
grad: 3.0 6.0 -1.6786572132332367e-12
Epoch: 99 w= 1.9999999999999236 loss= 5.250973729513143e-26
Predict (after training) 4 7.9999999999996945

Loss curve: (figure: loss vs. epoch)
