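Gradient descent: simply put, we search for the minimum of the loss function, i.e. the extremum point where its derivative is zero, by repeatedly adjusting the variable x until the loss function reaches its minimum.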
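Written as a formula, each iteration applies one update rule. Here θ stands for the variable being adjusted (`seeta` in the code below) and η for the step size (`step`). This notation is ours, not the original post's:

```latex
\theta_{k+1} = \theta_k - \eta \,\frac{dJ}{d\theta}\Big|_{\theta=\theta_k},
\qquad J(\theta) = (\theta - 4)^2 - 6,
\qquad \frac{dJ}{d\theta} = 2(\theta - 4)
```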
Gradient Descent (Part 1)
import numpy as np
from matplotlib import pyplot as plt
First, some data preparation, folks:
X = np.linspace(1,8,101)
Y = (X - 4) ** 2 -6
plt.plot(X,Y)
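Before running any iterations, we can read a rough answer straight off the sampled curve. This parabola's true minimum is at X = 4, Y = -6, and the grid spacing here is 0.07, so the lowest sample should sit next to it. The sketch below is a sanity check that the iterative method should later improve on:

```python
import numpy as np

# Same sample grid and curve as above
X = np.linspace(1, 8, 101)
Y = (X - 4) ** 2 - 6

i = np.argmin(Y)        # index of the smallest sampled Y
print(X[i], Y[i])       # roughly 4.01 and -5.9999, the grid point nearest (4, -6)
```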
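A zero derivative marks an extremum. So we let the derivative dJ guide each step: where dJ < 0, J decreases as the variable grows, so we increase it; where dJ > 0, we decrease it. We keep stepping until J stops decreasing, which gives the optimum we are after.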
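To see how the sign of dJ chooses the direction, here is a quick check from one point on each side of the minimum. It is a small sketch: `theta` plays the role of `seeta` in the notebook code, and the step size matches the one used later:

```python
def dJ(theta):
    # derivative of J(theta) = (theta - 4) ** 2 - 6
    return 2 * (theta - 4)

step = 0.01
for start in (1.0, 6.0):
    g = dJ(start)
    nxt = start - step * g   # step against the sign of the derivative
    print(f"theta={start}: dJ={g:+.2f}, next theta={nxt:.4f}")
# theta=1.0: dJ=-6.00, next theta=1.0600
# theta=6.0: dJ=+4.00, next theta=5.9600
```

Left of 4 the derivative is negative, so theta grows. Right of 4 it is positive, so theta shrinks. Either way theta moves toward 4, where dJ = 0.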
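Define the derivative dJ of the loss function. Our function is easy to differentiate, right? Anyone who did high-school calculus can do it: the derivative of (θ - 4)² - 6 is 2(θ - 4).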
def dJ(seeta):
    return 2 * (seeta - 4)  # derivative of J(seeta) = (seeta - 4) ** 2 - 6
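A hand-derived derivative is easy to get wrong (for example, by using the wrong variable name inside the function), so it is worth checking numerically. The sketch below compares dJ against a central finite difference. The helper `numeric_grad` and the step `h = 1e-5` are our own choices for illustration. For a quadratic, the central difference is exact up to rounding error:

```python
def J(seeta):
    return (seeta - 4) ** 2 - 6

def dJ(seeta):
    return 2 * (seeta - 4)

def numeric_grad(f, x, h=1e-5):
    # hypothetical helper: central difference (f(x + h) - f(x - h)) / (2h)
    return (f(x + h) - f(x - h)) / (2 * h)

for x in (1.0, 2.5, 4.0, 7.0):
    print(x, dJ(x), numeric_grad(J, x))
```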
Define J, i.e. the loss function:
def J(seeta):
    try:
        return (seeta - 4) ** 2 - 6
    except OverflowError:
        return float('inf')  # if the iteration diverges, the square can overflow, so return infinity to cap it
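The guard only matters for plain Python floats. For these, `**` raises `OverflowError` when the result is too large. NumPy values behave differently: they silently become `inf`, with at most a RuntimeWarning. The sketch below shows both cases. Catching `OverflowError` specifically, rather than using a bare `except`, is our choice here:

```python
import numpy as np

def J(seeta):
    try:
        return (seeta - 4) ** 2 - 6
    except OverflowError:
        return float("inf")

print(J(1e200))   # Python float: ** raises OverflowError, which is caught -> inf

with np.errstate(over="ignore"):        # silence NumPy's overflow warning
    print(J(np.float64(1e200)))         # NumPy: no exception, the result is already inf
```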
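step = 0.01              # step size for each update
exit_limt = 1e-6         # floating-point arithmetic rarely reaches exactly 0, so stop once J changes by less than this
seeta = 1                # start from 1
history_seeta = [seeta]  # records every value seeta takes
n_iters = 1e4            # maximum number of iterations, to avoid an infinite loop
i_iters = 0
while i_iters < n_iters:
    gradient = dJ(seeta)                 # derivative at the current point
    last_seeta = seeta
    seeta = seeta - step * gradient      # move seeta against the gradient so that J decreases
    history_seeta.append(seeta)
    if abs(J(seeta) - J(last_seeta)) < exit_limt:  # J has essentially stopped changing
        break
    i_iters += 1
print(seeta)
print(J(seeta))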
3.9951357787023483
-5.999976339351168
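The loop above can be wrapped in a function, so it is easy to rerun with other starting points or step sizes. This is a minimal sketch. The name `gradient_descent` and its parameter names are hypothetical, chosen for this example. The body performs the same update and the same stopping test on the change in J, so it should reproduce the numbers printed above:

```python
def J(seeta):
    try:
        return (seeta - 4) ** 2 - 6
    except OverflowError:
        return float("inf")

def dJ(seeta):
    return 2 * (seeta - 4)

def gradient_descent(seeta, step=0.01, exit_limit=1e-6, n_iters=10_000):
    """Return the final seeta and the list of every value it took."""
    history = [seeta]
    for _ in range(n_iters):
        last_seeta = seeta
        seeta = seeta - step * dJ(seeta)                 # step against the gradient
        history.append(seeta)
        if abs(J(seeta) - J(last_seeta)) < exit_limit:   # J barely changed: stop
            break
    return seeta, history

seeta, history = gradient_descent(1)
print(seeta, J(seeta), len(history))
```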
history_seeta
[1,
1.06,
1.1188,
1.176424,
1.23289552,
1.2882376096,
1.342472857408,
1.39562340025984,
1.447710932254643,
1.4987567136095503,
1.5487815793373594,
1.5978059477506121,
1.6458498287955998,
1.692932832219688,
1.739074175575294,
1.7842926920637883,
1.8286068382225125,
1.8720347014580623,
1.914594007428901,
1.956302127280323,
1.9971760847347164,
2.037232563040022,
2.0764879117792217,
2.114958153543637,
2.1526589904727644,
2.189605810663309,
2.225813694450043,
2.261297420561042,
2.296071472149821,
2.3301500427068245,
2.363547041852688,
2.3962761010156344,
2.4283505789953215,
2.459783567415415,
2.490587896067107,
2.5207761381457647,
2.5503606153828495,
2.5793534030751926,
2.6077663350136886,
2.6356110083134148,
2.6628987881471464,
2.6896408123842033,
2.715847996136519,
2.7415310362137886,
2.7667004154895127,
2.7913664071797224,
2.815539079036128,
2.8392282974554055,
2.8624437315062976,
2.8851948568761716,
2.9074909597386482,
2.929341140543875,
2.9507543177329976,
2.9717392313783377,
2.992304446750771,
3.0124583578157553,
3.03220919065944,
3.051565006846251,
3.0705337067093263,
3.0891230325751398,
3.107340571923637,
3.1251937604851645,
3.1426898852754612,
3.159836087569952,
3.176639365818553,
3.193106578502182,
3.2092444469321384,
3.2250595579934958,
3.2405583668336257,
3.255747199496953,
3.270632255507014,
3.285219610396874,
3.2995152181889367,
3.313524913825158,
3.3272544155486545,
3.3407093272376813,
3.3538951406929276,
3.366817237879069,
3.3794808931214875,
3.3918912752590575,
3.4040534497538766,
3.415972380758799,
3.427652933143623,
3.4390998744807506,
3.450317876991136,
3.461311519451313,
3.4720852890622864,
3.4826435832810407,
3.4929907116154197,
3.5031308973831115,
3.513068279435449,
3.52280691384674,
3.5323507755698054,
3.541703760058409,
3.5508696848572407,
3.559852291160096,
3.568655245336894,
3.5772821404301562,
3.585736497621553,
3.594021767669122,
3.60214133231574,
3.6100985056694253,
3.617896535556037,
3.625538604844916,
3.633027832748018,
3.6403672760930577,
3.6475599305711963,
3.6546087319597724,
3.661516557320577,
3.6682862261741653,
3.674920501650682,
3.6814220916176685,
3.687793649785315,
3.694037776789609,
3.700157021253817,
3.7061538808287406,
3.7120308032121656,
3.7177901871479224,
3.723434383404964,
3.7289656957368646,
3.7343863818221275,
3.739698654185685,
3.7449046811019713,
3.7500065874799318,
3.7550064557303333,
3.7599063266157264,
3.764708200083412,
3.7694140360817436,
3.7740257553601086,
3.7785452402529063,
3.782974335447848,
3.787314848738891,
3.7915685517641133,
3.795737180728831,
3.7998224371142544,
3.8038259883719694,
3.80774946860453,
3.8115944792324394,
3.8153625896477905,
3.819055337854835,
3.822674231097738,
3.826220746475783,
3.8296963315462675,
3.8331024049153424,
3.8364403568170355,
3.839711549680695,
3.842917318687081,
3.8460589723133394,
3.8491377928670727,
3.852155037009731,
3.8551119362695365,
3.8580096975441456,
3.8608495035932626,
3.8636325135213974,
3.8663598632509695,
3.86903266598595,
3.8716520126662313,
3.8742189724129066,
3.8767345929646484,
3.8791999011053555,
3.881615903083248,
3.8839835850215834,
3.886303913321152,
3.8885778350547286,
3.890806278353634,
3.8929901527865614,
3.89513034973083,
3.8972277427362134,
3.899283187881489,
3.9012975241238594,
3.903271573641382,
3.9052061421685544,
3.9071020193251833,
3.9089599789386797,
3.9107807793599063,
3.912565163772708,
3.914313860497254,
3.916027583287309,
3.9177070316215628,
3.9193528909891313,
3.920965833169349,
3.922546516505962,
3.924095586175843,
3.9256136744523262,
3.92710140096328,
3.9285593729440142,
3.929988185485134,
3.9313884217754316,
3.932760653339923,
3.934105440273125,
3.9354233314676623,
3.9367148648383092,
3.937980567541543,
3.939220956190712,
3.9404365370668977,
3.9416278063255596,
3.9427952501990484,
3.9439393451950675,
3.945060558291166,
3.9461593471253424,
3.9472361601828356,
3.948291436979179,
3.949325608239595,
3.9503390960748033,
3.951332314153307,
3.952305667870241,
3.953259554512836,
3.9541943634225794,
3.9551104761541276,
3.956008266631045,
3.956888101298424,
3.957750339272456,
3.9585953324870067,
3.9594234258372665,
3.9602349573205213,
3.961030258174111,
3.961809653010629,
3.9625734599504163,
3.963321990751408,
3.96405555093638,
3.9647744399176528,
3.9654789511193,
3.966169372096914,
3.9668459846549755,
3.967509064961876,
3.9681588836626385,
3.9687957059893857,
3.969419791869598,
3.9700313960322062,
3.9706307681115622,
3.971218152749331,
3.971793789694344,
3.9723579139004572,
3.972910755622448,
3.973452540509999,
3.973983489699799,
3.974503819905803,
3.975013743507687,
3.9755134686375335,
3.976003199264783,
3.976483135279487,
3.9769534725738973,
3.9774144031224195,
3.977866115059971,
3.9783087927587717,
3.978742616903596,
3.979167764565524,
3.9795844092742136,
3.9799927210887294,
3.980392866666955,
3.9807850093336157,
3.9811693091469436,
3.981545922964005,
3.981915004504725,
3.9822767044146303,
3.9826311703263375,
3.982978546919811,
3.9833189759814145,
3.9836525964617864,
3.9839795445325508,
3.9842999536418997,
3.9846139545690615,
3.9849216754776804,
3.985223241968127,
3.9855187771287643,
3.985808401586189,
3.9860922335544653,
3.986370388883376,
3.9866429811057085,
3.9869101214835942,
3.9871719190539223,
3.9874284806728437,
3.987679911059387,
3.987926312838199,
3.9881677865814353,
3.9884044308498066,
3.9886363422328106,
3.9888636153881545,
3.9890863430803916,
3.9893046162187837,
3.989518523894408,
3.98972815341652,
3.9899335903481896,
3.990134918541226,
3.9903322201704015,
3.9905255757669935,
3.9907150642516536,
3.9909007629666204,
3.991082747707288,
3.9912610927531422,
3.9914358708980795,
3.991607153480118,
3.9917750104105156,
3.991939510202305,
3.992100719998259,
3.9922587055982937,
3.992413531486328,
3.992565260856601,
3.992713955639469,
3.9928596765266797,
3.993002482996146,
3.993142433336223,
3.993279584669499,
3.9934139929761088,
3.993545713116587,
3.993674798854255,
3.99380130287717,
3.9939252768196267,
3.994046771283234,
3.9941658358575696,
3.9942825191404183,
3.99439686875761,
3.9945089313824576,
3.9946187527548083,
3.994726377699712,
3.994831850145718,
3.9949352131428033,
3.9950365088799473,
3.9951357787023483]
plt.plot(X,Y)
plt.plot(np.array(history_seeta),J(np.array(history_seeta)),marker="x",color="red")
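The red crosses bunch up as they approach 4, and the update rule explains why. Substituting dJ gives seeta - 4 ← (1 - 2·step)(seeta - 4). With step = 0.01, each move is therefore 0.98 times the previous one. The sketch below checks this on a freshly generated history, then tries a few other step sizes to show that |1 - 2·step| > 1 makes the distance grow instead. That growth is the divergence the overflow guard in J protects against. The helper `run_steps` is hypothetical and has no stopping rule:

```python
import numpy as np

def dJ(seeta):
    return 2 * (seeta - 4)

def run_steps(step, n, seeta=1.0):
    """Apply n plain gradient-descent updates and return every visited value."""
    history = [seeta]
    for _ in range(n):
        seeta = seeta - step * dJ(seeta)
        history.append(seeta)
    return np.array(history)

moves = np.diff(run_steps(0.01, 50))   # size of each update
print(moves[1:] / moves[:-1])          # each ratio is about 1 - 2 * 0.01 = 0.98

# factor 1 - 2*step: |factor| < 1 shrinks the distance to 4, |factor| > 1 makes it grow
for step in (0.01, 0.5, 0.9, 1.1):
    print(step, 1 - 2 * step, abs(run_steps(step, 20)[-1] - 4))
```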
And that's it: the principle of gradient descent is now sorted.
len(history_seeta)
319
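The 319 can also be predicted by hand. Write x_k = seeta_k - 4. Then x_k = -3 · 0.98^k, and update m changes J by x_{m-1}² (1 - 0.98²). The loop stops at the first update where that change drops below 1e-6. The sketch below solves for that update. The variable names are ours:

```python
import math

x0 = 1 - 4          # starting seeta minus the minimiser 4
r = 1 - 2 * 0.01    # factor by which seeta - 4 shrinks on every update
eps = 1e-6          # exit_limt in the notebook

# smallest j with x0**2 * r**(2*j) * (1 - r**2) < eps; the loop breaks on update j + 1
t = math.log(eps / (x0 ** 2 * (1 - r ** 2))) / (2 * math.log(r))
j = math.floor(t) + 1
n_updates = j + 1
print(n_updates, n_updates + 1)   # updates performed, length of history_seeta (with start)
```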