Machine Learning Basics (8): Gradient Descent, Part 1

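Gradient descent: put simply, we search for the minimum of the loss function, i.e. the extremum where its derivative is zero. We get there by repeatedly adjusting the variable x (called θ below) in the direction that lowers the loss, until the loss stops decreasing.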
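In symbols, one step of the method replaces θ with θ − η·J′(θ), where η is the step size. The sketch below shows only that update on a toy function f(x) = x², whose derivative is 2x; the helper name `gd_step`, the toy function, and η = 0.1 are illustrative assumptions, not part of the original post.

```python
def gd_step(x, grad, eta):
    """One gradient-descent update: move x a step of size eta against the derivative."""
    return x - eta * grad(x)


def grad_f(x):
    # Derivative of the toy function f(x) = x**2 (minimum at x = 0)
    return 2 * x


x = 3.0
for _ in range(5):
    x = gd_step(x, grad_f, eta=0.1)  # each step scales x by (1 - 2 * 0.1) = 0.8
print(x)  # about 3 * 0.8**5, roughly 0.98: much closer to the minimum than 3
```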

Gradient Descent (Part 1)

import numpy as np
from matplotlib import pyplot as plt
Data preparation first, folks:

X = np.linspace(1, 8, 101)   # 101 evenly spaced points on [1, 8]
Y = (X - 4) ** 2 - 6         # loss curve J(x) = (x - 4)^2 - 6, minimum -6 at x = 4

plt.plot(X,Y)


[Figure: the loss curve Y = (X - 4) ** 2 - 6 plotted over X in [1, 8]]


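A zero derivative marks an extremum, and for this upward-opening parabola that extremum is the minimum. Away from it, the sign of the derivative dJ tells us which way to move: where dJ < 0, J decreases as θ grows, so we increase θ; where dJ > 0, we decrease θ. Repeating the update θ ← θ − step · dJ(θ) gradually brings θ to the optimum we are looking for.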
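A quick numerical check of that sign rule, using the post's loss J(θ) = (θ − 4)² − 6 and its step size 0.01 (the two starting points 1 and 7 are arbitrary choices for illustration):

```python
def dJ(theta):
    # Derivative of J(theta) = (theta - 4)**2 - 6
    return 2 * (theta - 4)


step = 0.01
for theta in (1.0, 7.0):  # one point left of the minimum at 4, one to the right
    g = dJ(theta)
    new_theta = theta - step * g  # gradient-descent update
    direction = "right" if new_theta > theta else "left"
    print(f"theta={theta}: dJ={g:+.1f}, moves {direction} to {new_theta:.2f}")
```

Starting at θ = 1 the derivative is −6, so the update moves θ right to 1.06, which matches the second entry of history_seeta further down; starting at θ = 7 the derivative is +6 and θ moves left. Either way θ heads toward 4.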
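Next, define the derivative dJ. This function is easy to differentiate, right? High-school calculus is enough: for J(θ) = (θ − 4)² − 6, dJ/dθ = 2(θ − 4).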

# Derivative of the loss function: dJ/dθ = 2(θ - 4)
def dJ(seeta):
    return 2 * (seeta - 4)
    
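# The loss function J
def J(seeta):
    try:
        return (seeta - 4) ** 2 - 6
    except OverflowError:
        # If the steps go the wrong way and diverge, (seeta - 4) ** 2 can overflow
        # a Python float; return infinity instead of crashing
        return float('inf')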
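One way to double-check the hand-derived dJ is to compare it with a central finite difference of J. This is a sketch with an assumed helper `numeric_derivative` and an assumed step h = 1e-6; J is redefined here without the overflow guard to keep the snippet short.

```python
def J(theta):
    # Loss function from the post (overflow guard omitted here)
    return (theta - 4) ** 2 - 6


def dJ(theta):
    # Hand-derived derivative of J
    return 2 * (theta - 4)


def numeric_derivative(f, x, h=1e-6):
    # Central difference approximation: (f(x + h) - f(x - h)) / (2h)
    return (f(x + h) - f(x - h)) / (2 * h)


for theta in (1.0, 4.0, 6.5):
    print(theta, dJ(theta), numeric_derivative(J, theta))
```

For a quadratic the central difference is exact apart from floating-point rounding, so the last two columns should agree to several decimal places.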

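step = 0.01              # step size: how far to move on each update
exit_limit = 1e-6        # floating-point math rarely gives exactly 0, so stop once J changes by less than this
seeta = 1                # start the search at θ = 1
history_seeta = [seeta]  # records every θ value visited
n_iters = 1e4            # upper bound on iterations, to avoid an infinite loop
i_iters = 0

while i_iters < n_iters:
    gradient = dJ(seeta)             # derivative at the current θ
    last_seeta = seeta
    seeta = seeta - step * gradient  # move θ against the gradient so that J decreases
    history_seeta.append(seeta)

    if abs(J(seeta) - J(last_seeta)) < exit_limit:
        break

    i_iters += 1

print(seeta)
print(J(seeta))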
3.9951357787023483
-5.999976339351168
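These numbers can be checked by hand. Because dJ(θ) = 2(θ − 4), each update gives θ_new − 4 = (1 − 2 × 0.01)(θ − 4) = 0.98(θ − 4), so starting from θ = 1 the k-th iterate is θ_k = 4 − 3 × 0.98^k. The sketch below reruns the same loop (with the variables renamed `theta` and `history`) and compares it against that closed form:

```python
def dJ(theta):
    return 2 * (theta - 4)


def J(theta):
    return (theta - 4) ** 2 - 6


step, exit_limit = 0.01, 1e-6
theta = 1.0
history = [theta]
for _ in range(10_000):
    last = theta
    theta = theta - step * dJ(theta)
    history.append(theta)
    if abs(J(theta) - J(last)) < exit_limit:
        break

# Closed form derived above: theta_k = 4 - 3 * 0.98**k
closed = [4 - 3 * 0.98 ** k for k in range(len(history))]
max_gap = max(abs(a - b) for a, b in zip(history, closed))
print(len(history), history[-1], max_gap)
```

If the closed form is right, the printout shows 319 stored values, the same final θ as above, and a gap on the order of floating-point rounding.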

history_seeta
[1,
 1.06,
 1.1188,
 1.176424,
 1.23289552,
 1.2882376096,
 1.342472857408,
 1.39562340025984,
 1.447710932254643,
 1.4987567136095503,
 1.5487815793373594,
 1.5978059477506121,
 1.6458498287955998,
 1.692932832219688,
 1.739074175575294,
 1.7842926920637883,
 1.8286068382225125,
 1.8720347014580623,
 1.914594007428901,
 1.956302127280323,
 1.9971760847347164,
 2.037232563040022,
 2.0764879117792217,
 2.114958153543637,
 2.1526589904727644,
 2.189605810663309,
 2.225813694450043,
 2.261297420561042,
 2.296071472149821,
 2.3301500427068245,
 2.363547041852688,
 2.3962761010156344,
 2.4283505789953215,
 2.459783567415415,
 2.490587896067107,
 2.5207761381457647,
 2.5503606153828495,
 2.5793534030751926,
 2.6077663350136886,
 2.6356110083134148,
 2.6628987881471464,
 2.6896408123842033,
 2.715847996136519,
 2.7415310362137886,
 2.7667004154895127,
 2.7913664071797224,
 2.815539079036128,
 2.8392282974554055,
 2.8624437315062976,
 2.8851948568761716,
 2.9074909597386482,
 2.929341140543875,
 2.9507543177329976,
 2.9717392313783377,
 2.992304446750771,
 3.0124583578157553,
 3.03220919065944,
 3.051565006846251,
 3.0705337067093263,
 3.0891230325751398,
 3.107340571923637,
 3.1251937604851645,
 3.1426898852754612,
 3.159836087569952,
 3.176639365818553,
 3.193106578502182,
 3.2092444469321384,
 3.2250595579934958,
 3.2405583668336257,
 3.255747199496953,
 3.270632255507014,
 3.285219610396874,
 3.2995152181889367,
 3.313524913825158,
 3.3272544155486545,
 3.3407093272376813,
 3.3538951406929276,
 3.366817237879069,
 3.3794808931214875,
 3.3918912752590575,
 3.4040534497538766,
 3.415972380758799,
 3.427652933143623,
 3.4390998744807506,
 3.450317876991136,
 3.461311519451313,
 3.4720852890622864,
 3.4826435832810407,
 3.4929907116154197,
 3.5031308973831115,
 3.513068279435449,
 3.52280691384674,
 3.5323507755698054,
 3.541703760058409,
 3.5508696848572407,
 3.559852291160096,
 3.568655245336894,
 3.5772821404301562,
 3.585736497621553,
 3.594021767669122,
 3.60214133231574,
 3.6100985056694253,
 3.617896535556037,
 3.625538604844916,
 3.633027832748018,
 3.6403672760930577,
 3.6475599305711963,
 3.6546087319597724,
 3.661516557320577,
 3.6682862261741653,
 3.674920501650682,
 3.6814220916176685,
 3.687793649785315,
 3.694037776789609,
 3.700157021253817,
 3.7061538808287406,
 3.7120308032121656,
 3.7177901871479224,
 3.723434383404964,
 3.7289656957368646,
 3.7343863818221275,
 3.739698654185685,
 3.7449046811019713,
 3.7500065874799318,
 3.7550064557303333,
 3.7599063266157264,
 3.764708200083412,
 3.7694140360817436,
 3.7740257553601086,
 3.7785452402529063,
 3.782974335447848,
 3.787314848738891,
 3.7915685517641133,
 3.795737180728831,
 3.7998224371142544,
 3.8038259883719694,
 3.80774946860453,
 3.8115944792324394,
 3.8153625896477905,
 3.819055337854835,
 3.822674231097738,
 3.826220746475783,
 3.8296963315462675,
 3.8331024049153424,
 3.8364403568170355,
 3.839711549680695,
 3.842917318687081,
 3.8460589723133394,
 3.8491377928670727,
 3.852155037009731,
 3.8551119362695365,
 3.8580096975441456,
 3.8608495035932626,
 3.8636325135213974,
 3.8663598632509695,
 3.86903266598595,
 3.8716520126662313,
 3.8742189724129066,
 3.8767345929646484,
 3.8791999011053555,
 3.881615903083248,
 3.8839835850215834,
 3.886303913321152,
 3.8885778350547286,
 3.890806278353634,
 3.8929901527865614,
 3.89513034973083,
 3.8972277427362134,
 3.899283187881489,
 3.9012975241238594,
 3.903271573641382,
 3.9052061421685544,
 3.9071020193251833,
 3.9089599789386797,
 3.9107807793599063,
 3.912565163772708,
 3.914313860497254,
 3.916027583287309,
 3.9177070316215628,
 3.9193528909891313,
 3.920965833169349,
 3.922546516505962,
 3.924095586175843,
 3.9256136744523262,
 3.92710140096328,
 3.9285593729440142,
 3.929988185485134,
 3.9313884217754316,
 3.932760653339923,
 3.934105440273125,
 3.9354233314676623,
 3.9367148648383092,
 3.937980567541543,
 3.939220956190712,
 3.9404365370668977,
 3.9416278063255596,
 3.9427952501990484,
 3.9439393451950675,
 3.945060558291166,
 3.9461593471253424,
 3.9472361601828356,
 3.948291436979179,
 3.949325608239595,
 3.9503390960748033,
 3.951332314153307,
 3.952305667870241,
 3.953259554512836,
 3.9541943634225794,
 3.9551104761541276,
 3.956008266631045,
 3.956888101298424,
 3.957750339272456,
 3.9585953324870067,
 3.9594234258372665,
 3.9602349573205213,
 3.961030258174111,
 3.961809653010629,
 3.9625734599504163,
 3.963321990751408,
 3.96405555093638,
 3.9647744399176528,
 3.9654789511193,
 3.966169372096914,
 3.9668459846549755,
 3.967509064961876,
 3.9681588836626385,
 3.9687957059893857,
 3.969419791869598,
 3.9700313960322062,
 3.9706307681115622,
 3.971218152749331,
 3.971793789694344,
 3.9723579139004572,
 3.972910755622448,
 3.973452540509999,
 3.973983489699799,
 3.974503819905803,
 3.975013743507687,
 3.9755134686375335,
 3.976003199264783,
 3.976483135279487,
 3.9769534725738973,
 3.9774144031224195,
 3.977866115059971,
 3.9783087927587717,
 3.978742616903596,
 3.979167764565524,
 3.9795844092742136,
 3.9799927210887294,
 3.980392866666955,
 3.9807850093336157,
 3.9811693091469436,
 3.981545922964005,
 3.981915004504725,
 3.9822767044146303,
 3.9826311703263375,
 3.982978546919811,
 3.9833189759814145,
 3.9836525964617864,
 3.9839795445325508,
 3.9842999536418997,
 3.9846139545690615,
 3.9849216754776804,
 3.985223241968127,
 3.9855187771287643,
 3.985808401586189,
 3.9860922335544653,
 3.986370388883376,
 3.9866429811057085,
 3.9869101214835942,
 3.9871719190539223,
 3.9874284806728437,
 3.987679911059387,
 3.987926312838199,
 3.9881677865814353,
 3.9884044308498066,
 3.9886363422328106,
 3.9888636153881545,
 3.9890863430803916,
 3.9893046162187837,
 3.989518523894408,
 3.98972815341652,
 3.9899335903481896,
 3.990134918541226,
 3.9903322201704015,
 3.9905255757669935,
 3.9907150642516536,
 3.9909007629666204,
 3.991082747707288,
 3.9912610927531422,
 3.9914358708980795,
 3.991607153480118,
 3.9917750104105156,
 3.991939510202305,
 3.992100719998259,
 3.9922587055982937,
 3.992413531486328,
 3.992565260856601,
 3.992713955639469,
 3.9928596765266797,
 3.993002482996146,
 3.993142433336223,
 3.993279584669499,
 3.9934139929761088,
 3.993545713116587,
 3.993674798854255,
 3.99380130287717,
 3.9939252768196267,
 3.994046771283234,
 3.9941658358575696,
 3.9942825191404183,
 3.99439686875761,
 3.9945089313824576,
 3.9946187527548083,
 3.994726377699712,
 3.994831850145718,
 3.9949352131428033,
 3.9950365088799473,
 3.9951357787023483]


plt.plot(X,Y)
plt.plot(np.array(history_seeta),J(np.array(history_seeta)),marker="x",color="red")

That covers the basic principle of gradient descent.
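The step size matters as much as the direction. Since each update multiplies θ − 4 by (1 − 2·step), the loop converges for 0 < step < 1 and diverges beyond that. The sketch below wraps the post's loop in a hypothetical helper `run_gd` and tries a few step sizes chosen for illustration; it also stops early once J overflows to infinity, which is the situation the `float('inf')` fallback in J is meant to catch.

```python
def dJ(theta):
    return 2 * (theta - 4)


def J(theta):
    try:
        return (theta - 4) ** 2 - 6
    except OverflowError:
        return float('inf')  # the squared distance no longer fits in a float


def run_gd(step, theta=1.0, n_iters=10_000, exit_limit=1e-6):
    """Hypothetical helper: the post's loop wrapped in a function."""
    history = [theta]
    for _ in range(n_iters):
        last = theta
        theta = theta - step * dJ(theta)
        history.append(theta)
        if J(theta) == float('inf'):
            break  # diverged: stop before the values turn into inf/nan
        if abs(J(theta) - J(last)) < exit_limit:
            break  # converged: J barely changes any more
    return history


for step in (0.01, 0.1, 0.9, 1.1):
    h = run_gd(step)
    print(f"step={step}: {len(h)} values, final theta={h[-1]:.6g}, J={J(h[-1]):.6g}")
```

With step = 0.01 the run reproduces the post's 319 values; 0.1 gets there in far fewer steps; 0.9 overshoots 4 on every step, alternating sides, but still converges; 1.1 overshoots by more each time until J overflows.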

len(history_seeta)
319
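Where does 319 come from? A short derivation, assuming the closed form θ_k − 4 = −3·0.98^k that follows from the update rule with step 0.01 and θ_0 = 1:

```latex
\begin{aligned}
d_k &= \theta_k - 4 = -3\,(0.98)^k, \\
J(\theta_{k+1}) - J(\theta_k) &= d_{k+1}^2 - d_k^2 = (0.98^2 - 1)\,d_k^2 = -0.0396\,d_k^2, \\
\text{stop when } 0.0396 \cdot 9\,(0.98)^{2k} < 10^{-6}
  &\iff k > \frac{\ln\!\left(10^{-6} / 0.3564\right)}{2\ln 0.98} \approx 316.4 .
\end{aligned}
```

So the first iteration that passes the test is k = 317: the loop performs 318 updates (k = 0 to 317), and history_seeta holds θ_0 through θ_318, i.e. 319 values.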


Reposted from blog.csdn.net/qq_37982109/article/details/87978589