【Python机器学习】实验02 线性回归

线性回归

1. 单变量的线性回归

import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['SimHei'] # use the SimHei font so Chinese labels display correctly
plt.rcParams['axes.unicode_minus']=False # display the minus sign correctly with this font

1.1 数据读取

# Load the single-variable dataset; its columns are 人口 (population) and 收益 (profit)
data=pd.read_csv("data/regress_data1.csv")
# Preview the first rows
data.head()
人口 收益
0 6.1101 17.5920
1 5.5277 9.1302
2 8.5186 13.6620
3 7.0032 11.8540
4 5.8598 6.8233
# Visualize the relationship between population (人口) and profit (收益)
data.plot(kind="scatter",x="人口",y="收益")
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
Text(0.5, 1.0, '人口与收益之间的关系')

0

1.2 训练数据的准备

# Prepend a constant column of ones (x0 = 1) so the bias w0 is learned as an
# ordinary weight in X @ w.
# NOTE(review): DataFrame.insert refuses to add a column that already exists,
# so re-running this cell on the same `data` will raise — reload the CSV first.
data.insert(0,"ones",1)
data
ones 人口 收益
0 1 6.1101 17.59200
1 1 5.5277 9.13020
2 1 8.5186 13.66200
3 1 7.0032 11.85400
4 1 5.8598 6.82330
... ... ... ...
92 1 5.8707 7.20290
93 1 5.3054 1.98690
94 1 8.2934 0.14454
95 1 13.3940 9.05510
96 1 5.4369 0.61705

97 rows × 3 columns

# Number of columns (ones + features + label) and number of samples
col_num=data.shape[1]
m=data.shape[0]
# Training-set features: every column except the last (includes the ones column)
X=data.iloc[:,:col_num-1]
# Training-set labels: the last column
y=data.iloc[:,col_num-1]
# Convert to plain numpy arrays
X=X.values
y=y.values
X.shape,y.shape
((97, 2), (97,))
# Make y a column vector so that X @ w - y is (m, 1); with a 1-D y the
# subtraction would broadcast to an (m, m) matrix.
y=y.reshape((m,1))
y.shape
(97, 1)
# Initialize the weight vector to zeros, one row per column of X (bias + features)
w=np.zeros((col_num-1,1))
w.shape
(2, 1)

1.3 假设函数的定义——假设函数用于预测

# Hypothesis function: estimate yhat
def h(X,w):
    """Return the linear-model predictions ``X @ w``.

    In this notebook X has shape (m, col_num-1) and w has shape
    (col_num-1, 1), so the predictions come out as an (m, 1) column.
    """
    return X @ w

1.4 损失函数的定义

# Halved mean-squared-error cost: J(w) = sum((X @ w - y)^2) / (2m)
def cost(X,y,w):
    """Return the halved MSE cost of weights ``w`` on data ``(X, y)``.

    The sample count m is taken from X itself rather than from the notebook
    global ``m``, so the function gives the right value for any dataset.

    Args:
        X: design matrix of shape (m, n).
        y: targets, shape (m,) or (m, 1).
        w: weights, shape (n, 1) or (n,).

    Returns:
        sum((X @ w - y)**2) / (2 * m)

    Raises:
        ValueError: if X has no rows (the mean would be undefined).
    """
    m_samples = X.shape[0]
    if m_samples == 0:
        raise ValueError("X must contain at least one sample")
    # X @ w is the hypothesis h(X, w). Both sides are flattened so a 1-D y
    # cannot silently broadcast against an (m, 1) prediction into (m, m).
    residual = np.ravel(X @ w) - np.ravel(y)
    return np.sum(np.square(residual)) / (2 * m_samples)
def computeCost(X,y,w):
    """Alternative implementation of the halved MSE cost.

    Computes sum((X @ w - y)**2) / (2m), with m = len(X) taken from the
    data itself (not from the notebook global ``m``).

    Args:
        X: design matrix of shape (m, n).
        y: targets, shape (m,) or (m, 1).
        w: weights, shape (n, 1) or (n,).

    Raises:
        ValueError: if X has no rows.
    """
    if len(X) == 0:
        raise ValueError("X must contain at least one sample")
    # (m, n) @ (n, 1) -> (m, 1); flatten both sides so a 1-D y cannot
    # broadcast into an (m, m) matrix.
    inner = np.power(np.ravel(X @ w) - np.ravel(y), 2)
    return np.sum(inner) / (2 * len(X))
cost(X,y,w)
32.072733877455676
error=h(X,w)-y
error.shape
(97, 1)
x1=np.array([1,2]).reshape(2,1)
x2=np.array([3,4]).reshape(2,1)
np.multiply(x1,x2)
array([[3],
       [8]])
X[:,1].shape
(97,)
X.shape,w.shape,y.shape
((97, 2), (2, 1), (97, 1))
h(X,w)-y
array([[-17.592  ],
       [ -9.1302 ],
       [-13.662  ],
       [-11.854  ],
       [ -6.8233 ],
       [-11.886  ],
       [ -4.3483 ],
       [-12.     ],
       [ -6.5987 ],
       [ -3.8166 ],
       [ -3.2522 ],
       [-15.505  ],
       [ -3.1551 ],
       [ -7.2258 ],
       [ -0.71618],
       [ -3.5129 ],
       [ -5.3048 ],
       [ -0.56077],
       [ -3.6518 ],
       [ -5.3893 ],
       [ -3.1386 ],
       [-21.767  ],
       [ -4.263  ],
       [ -5.1875 ],
       [ -3.0825 ],
       [-22.638  ],
       [-13.501  ],
       [ -7.0467 ],
       [-14.692  ],
       [-24.147  ],
       [  1.22   ],
       [ -5.9966 ],
       [-12.134  ],
       [ -1.8495 ],
       [ -6.5426 ],
       [ -4.5623 ],
       [ -4.1164 ],
       [ -3.3928 ],
       [-10.117  ],
       [ -5.4974 ],
       [ -0.55657],
       [ -3.9115 ],
       [ -5.3854 ],
       [ -2.4406 ],
       [ -6.7318 ],
       [ -1.0463 ],
       [ -5.1337 ],
       [ -1.844  ],
       [ -8.0043 ],
       [ -1.0179 ],
       [ -6.7504 ],
       [ -1.8396 ],
       [ -4.2885 ],
       [ -4.9981 ],
       [ -1.4233 ],
       [  1.4211 ],
       [ -2.4756 ],
       [ -4.6042 ],
       [ -3.9624 ],
       [ -5.4141 ],
       [ -5.1694 ],
       [  0.74279],
       [-17.929  ],
       [-12.054  ],
       [-17.054  ],
       [ -4.8852 ],
       [ -5.7442 ],
       [ -7.7754 ],
       [ -1.0173 ],
       [-20.992  ],
       [ -6.6799 ],
       [ -4.0259 ],
       [ -1.2784 ],
       [ -3.3411 ],
       [  2.6807 ],
       [ -0.29678],
       [ -3.8845 ],
       [ -5.7014 ],
       [ -6.7526 ],
       [ -2.0576 ],
       [ -0.47953],
       [ -0.20421],
       [ -0.67861],
       [ -7.5435 ],
       [ -5.3436 ],
       [ -4.2415 ],
       [ -6.7981 ],
       [ -0.92695],
       [ -0.152  ],
       [ -2.8214 ],
       [ -1.8451 ],
       [ -4.2959 ],
       [ -7.2029 ],
       [ -1.9869 ],
       [ -0.14454],
       [ -9.0551 ],
       [ -0.61705]])
np.multiply((h(X,w)-y).ravel(),X[:,1]).shape
(97,)

1.5 利用梯度下降算法来优化参数w

# Batch gradient descent over all samples.
# Hyper-parameters: the iteration count iter_num and the learning rate alpha.
def gradient_descent(X,y,w,iter_num,alpha):
    """Fit linear-regression weights with full-batch gradient descent.

    Minimizes J(w) = sum((X @ w - y)**2) / (2m) using the simultaneous update
    w <- w - (alpha / m) * X.T @ (X @ w - y).

    The sample count m and the number of weights come from X itself, so the
    function works for any design matrix. It does not rely on the notebook
    globals ``m`` / ``col_num``.

    Args:
        X: design matrix of shape (m, n); in this notebook column 0 is the
            constant ones column for the bias w0.
        y: targets, shape (m,) or (m, 1).
        w: initial weights, n values (shape (n, 1) or (n,)); not modified.
        iter_num: number of iterations.
        alpha: learning rate.

    Returns:
        (w, cost_lst): the final weights as an (n, 1) array, and the cost
        J(w) recorded after every iteration.

    Raises:
        ValueError: if X contains no samples.
    """
    X = np.asarray(X, dtype=float)
    y_col = np.asarray(y, dtype=float).reshape(-1, 1)
    w = np.asarray(w, dtype=float).reshape(-1, 1)
    m_samples = X.shape[0]
    if m_samples == 0:
        raise ValueError("X must contain at least one sample")
    cost_lst = []
    for _ in range(iter_num):
        error = X @ w - y_col
        # One vectorized step updates every weight from the same error vector
        # (a true simultaneous update). A new array is created, so the
        # caller's w is never mutated.
        w = w - (alpha / m_samples) * (X.T @ error)
        residual = X @ w - y_col
        cost_lst.append(np.sum(np.square(residual)) / (2 * m_samples))
    return w, cost_lst
# Hyper-parameters for the single-variable fit
iter_num=200
alpha=0.003
# Start from all-zero weights
w=np.zeros((col_num-1,1))
# Run gradient descent; cost_lst holds the cost after each iteration
w,cost_lst=gradient_descent(X,y,w,iter_num,alpha)
w
array([[-0.32791203],
       [ 0.83460252]])
cost
<function __main__.cost(X, y, w)>

1.6 可视化误差曲线

# Plot the cost recorded after each gradient-descent iteration
plt.plot(range(iter_num),cost_lst,"r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()

1

1.7 可视化回归线/回归平面

# 50 evenly spaced population values spanning the data range
x=np.linspace(data["人口"].min(),data["人口"].max(),50)
# Fitted line: w0 * 1 (bias) + w1 * population
y1=w[0,0]*1+w[1,0]*x
plt.plot(x,y1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
# The label= arguments above are only displayed once a legend is drawn
plt.legend()
plt.show()

2

w
array([[-0.32791203],
       [ 0.83460252]])

总结:

  1. 数据准备
  2. 初始化w
  3. 定义了假设函数
  4. 定义了损失函数或者代价函数
  5. 定义梯度下降算法
  6. 可视化分析

1.2 单变量的线性回归–基于sklearn试试?

X.shape,y.shape
((97, 2), (97, 1))
import sklearn
from sklearn import linear_model
# Fit the same single-variable problem with scikit-learn for comparison.
# NOTE: X already contains the constant "ones" column, while LinearRegression
# also fits its own intercept (fit_intercept=True, see get_params below).
# As a result the ones column gets coefficient 0 and the bias appears in
# reg.intercept_ instead (see the coef_/intercept_ outputs).
reg=linear_model.LinearRegression()
reg.fit(X,y)
reg.coef_
array([[0.        , 1.19303364]])
w
array([[-0.32791203],
       [ 0.83460252]])
reg.intercept_
array([-3.89578088])
reg.get_params()
{'copy_X': True,
 'fit_intercept': True,
 'n_jobs': None,
 'normalize': 'deprecated',
 'positive': False}
reg.predict(X)-y
array([[-14.19822601],
       [ -6.4312488 ],
       [ -7.39480448],
       [ -7.39472766],
       [ -3.72814233],
       [ -5.78069914],
       [  0.67551586],
       [ -5.66181898],
       [ -2.75622606],
       [ -1.68207302],
       [ -0.33492365],
       [ -2.50265234],
       [ -0.21002596],
       [ -1.09007678],
       [  2.117584  ],
       [ -0.99087569],
       [ -1.60644452],
       [  1.66383102],
       [  0.12314824],
       [ -0.84937859],
       [  0.34942365],
       [ -1.47998891],
       [ -1.60890687],
       [ -1.53603074],
       [ -0.33916795],
       [ -3.93175849],
       [ -2.09254529],
       [  2.12958876],
       [ -2.86836958],
       [ -1.55385488],
       [  3.59050903],
       [ -2.03100498],
       [ -4.99636713],
       [  1.28383475],
       [ -0.64226232],
       [  1.00673223],
       [  1.6465002 ],
       [ -0.60007636],
       [  1.30099898],
       [ -1.81336092],
       [  1.99826273],
       [  0.40377318],
       [  4.68685703],
       [  0.55183747],
       [ -1.29245052],
       [  3.52022606],
       [ -2.9805617 ],
       [  1.18148451],
       [  2.05841276],
       [  1.69763436],
       [ -1.65046859],
       [  0.59688379],
       [  0.67268159],
       [  0.17687322],
       [  2.23616258],
       [  5.11170076],
       [  1.11395081],
       [ -1.77162904],
       [  3.24920096],
       [  1.96858198],
       [  1.46381825],
       [  3.02608828],
       [  3.56178204],
       [  1.83596469],
       [  1.66894398],
       [ -0.16942543],
       [  0.2563525 ],
       [  0.5407115 ],
       [  1.64788834],
       [ -0.62028352],
       [  1.51690814],
       [  0.82862438],
       [  1.9914178 ],
       [  1.38386093],
       [  4.78217995],
       [  3.61930412],
       [  1.21352255],
       [ -3.58846693],
       [  1.60884678],
       [  0.14027707],
       [  2.45981748],
       [  2.08994488],
       [  3.00817305],
       [  0.21510688],
       [ -1.46569296],
       [  2.02402528],
       [  0.25840658],
       [  2.33785705],
       [  2.53824205],
       [ -0.68114646],
       [  1.06859725],
       [  0.91903985],
       [ -4.09473826],
       [  0.44683982],
       [  5.85398435],
       [  3.02861175],
       [  1.97357374]])
reg.score(X,y)
0.7020315537841397

1.3 多变量线性回归

# Load the multivariable dataset: 面积 (area), 房间数 (rooms), 价格 (price)
path = 'data/regress_data2.csv'
data2 = pd.read_csv(path)
data2.head()
面积 房间数 价格
0 2104 3 399900
1 1600 3 329900
2 2400 3 369000
3 1416 2 232000
4 3000 4 539900
# Z-score standardize every column, the target 价格 included. pandas .std()
# is the sample standard deviation (ddof=1) by default. Without this the
# features differ in scale by orders of magnitude and gradient descent
# would need a much smaller learning rate.
data2=(data2-data2.mean())/data2.std()
data2.head()
面积 房间数 价格
0 0.130010 -0.223675 0.475747
1 -0.504190 -0.223675 -0.084074
2 0.502476 -0.223675 0.228626
3 -0.735723 -1.537767 -0.867025
4 1.257476 1.090417 1.595389

实验要求1 准备训练数据

# Prepend the constant ones column for the bias term
data2.insert(0,"ones",1)
# Globals used by cost2 / the redefined gradient_descent below
col_num2=data2.shape[1]
m2=data2.shape[0]
# Features: all columns but the last; label: the last column as an (m, 1) column vector
X2=data2.iloc[:,:-1].values
y2=data2.iloc[:,-1].values.reshape((data2.shape[0],1))
# One zero weight per column of X2
w2=np.zeros((X2.shape[1],1))
X2.shape,y2.shape,w2.shape
((47, 3), (47, 1), (3, 1))

实验要求2 调用前面的梯度下降算法

# Halved mean-squared-error cost for the multivariable data:
# J(w) = sum((X @ w - y)^2) / (2m)
def cost2(X,y,w):
    """Return the halved MSE cost of weights ``w`` on data ``(X, y)``.

    The sample count is read from X instead of the notebook global ``m2``,
    so the result is correct for whatever data is passed in.

    Args:
        X: design matrix of shape (m, n).
        y: targets, shape (m,) or (m, 1).
        w: weights, shape (n, 1) or (n,).

    Raises:
        ValueError: if X has no rows.
    """
    m_samples = X.shape[0]
    if m_samples == 0:
        raise ValueError("X must contain at least one sample")
    # X @ w is the hypothesis; flatten both sides to avoid (m, m) broadcasting
    residual = np.ravel(X @ w) - np.ravel(y)
    return np.sum(np.square(residual)) / (2 * m_samples)
# Batch gradient descent over all samples.
# Hyper-parameters: the iteration count iter_num and the learning rate alpha.
def gradient_descent(X,y,w,iter_num,alpha):
    """Fit linear-regression weights with full-batch gradient descent.

    Minimizes J(w) = sum((X @ w - y)**2) / (2m) using the simultaneous update
    w <- w - (alpha / m) * X.T @ (X @ w - y).

    The sample count m and the number of weights come from X itself, so the
    function does not rely on the notebook globals ``m2`` / ``col_num2`` and
    works for any number of features.

    Args:
        X: design matrix of shape (m, n); in this notebook column 0 is the
            constant ones column for the bias w0.
        y: targets, shape (m,) or (m, 1).
        w: initial weights, n values (shape (n, 1) or (n,)); not modified.
        iter_num: number of iterations.
        alpha: learning rate.

    Returns:
        (w, cost_lst): the final weights as an (n, 1) array, and the cost
        J(w) recorded after every iteration.

    Raises:
        ValueError: if X contains no samples.
    """
    X = np.asarray(X, dtype=float)
    y_col = np.asarray(y, dtype=float).reshape(-1, 1)
    w = np.asarray(w, dtype=float).reshape(-1, 1)
    m_samples = X.shape[0]
    if m_samples == 0:
        raise ValueError("X must contain at least one sample")
    cost_lst = []
    for _ in range(iter_num):
        error = X @ w - y_col
        # Vectorized simultaneous update of every weight. A new array is
        # created, so the caller's w is never mutated.
        w = w - (alpha / m_samples) * (X.T @ error)
        residual = X @ w - y_col
        cost_lst.append(np.sum(np.square(residual)) / (2 * m_samples))
    return w, cost_lst
# Hyper-parameters for the multivariable fit (features were standardized above)
iter_num2=1000
alpha2=0.01
w2,cost_lst2=gradient_descent(X2,y2,w2,iter_num2,alpha2)
# The bias weight comes out ~0 (see output) because every column, the target included, was centered
w2
array([[-1.03191687e-16],
       [ 8.78503652e-01],
       [-4.69166570e-02]])
cost_lst2
[0.4805491041076719,
 0.47198587701203876,
 0.46366461618706284,
 0.4555781400525299,
 0.44771948335326117,
 0.4400818906150644,
 0.43265880979889004,
 0.42544388614718714,
 0.41843095621663473,
 0.4116140420916035,
 0.4049873457728717,
 0.39854524373628347,
 0.3922822816562035,
 0.38619316928877434,
 0.3802727755101314,
 0.3745161235048873,
 0.36891838610032585,
 0.36347488124189714,
 0.3581810676057273,
 0.353032540343996,
 0.34802502695915444,
 0.3431543833030803,
 0.33841658969738386,
 0.3338077471711977,
 0.3293240738128865,
 0.32496190123222957,
 0.32071767112972566,
 0.3165879319697778,
 0.3125693357546089,
 0.3086586348958572,
 0.3048526791808924,
 0.301148412830983,
 0.29754287164853055,
 0.29403318025067643,
 0.290616549386659,
 0.2872902733363892,
 0.2840517273877804,
 0.2808983653904495,
 0.2778277173834725,
 0.2748373872949541,
 0.27192505071123485,
 0.2690884527136251,
 0.26632540578062175,
 0.26363378775362334,
 0.2610115398642199,
 0.2584566648211922,
 0.25596722495541263,
 0.2535413404208921,
 0.2511771874502738,
 0.24887299666312288,
 0.2466270514254147,
 0.2444376862586688,
 0.2423032852972262,
 0.24022228079221122,
 0.23819315166076374,
 0.23621442207916982,
 0.23428466011856308,
 0.23240247642190492,
 0.2305665229209955,
 0.22877549159230046,
 0.22702811325042083,
 0.2253231563780625,
 0.2236594259914031,
 0.22203576253978147,
 0.22045104083867165,
 0.21890416903493287,
 0.2173940876033578,
 0.2159197683735719,
 0.21448021358636368,
 0.21307445497855435,
 0.21170155289554488,
 0.21036059543069827,
 0.20905069759074849,
 0.2077710004864442,
 0.20652067054766643,
 0.20529889876227617,
 0.20410489993797507,
 0.20293791198648126,
 0.20179719522934592,
 0.20068203172475318,
 0.1995917246146703,
 0.19852559749172968,
 0.1974829937852473,
 0.19646327616579626,
 0.19546582596777473,
 0.1944900426294234,
 0.1935353431497636,
 0.19260116156194368,
 0.1916869484224977,
 0.1907921703160338,
 0.189916309374886,
 0.18905886281327441,
 0.18821934247553765,
 0.1873972743980089,
 0.1865921983841233,
 0.185803667592357,
 0.1850312481366081,
 0.1842745186986435,
 0.18353307015224665,
 0.18280650519871142,
 0.18209443801333897,
 0.18139649390260434,
 0.1807123089716706,
 0.18004152980193594,
 0.17938381313831125,
 0.17873882558593354,
 0.17810624331602853,
 0.17748575178064738,
 0.1768770454360068,
 0.17627982747417442,
 0.17569380956284517,
 0.17511871159296452,
 0.17455426143396124,
 0.17400019469635858,
 0.17345625450154267,
 0.17292219125846853,
 0.17239776244709656,
 0.1718827324083545,
 0.1713768721404272,
 0.170879959101184,
 0.17039177701655678,
 0.16991211569468956,
 0.16944077084568412,
 0.16897754390677372,
 0.16852224187275908,
 0.16807467713154944,
 0.16763466730465196,
 0.16720203509246181,
 0.16677660812420697,
 0.16635821881240656,
 0.1659467042117072,
 0.165541905881964,
 0.16514366975543857,
 0.16475184600798926,
 0.16436628893413294,
 0.16398685682586134,
 0.1636134118550983,
 0.16324581995968843,
 0.1628839507328093,
 0.16252767731570503,
 0.16217687629363958,
 0.16183142759497396,
 0.16149121439327116,
 0.1611561230123392,
 0.1608260428341213,
 0.1605008662093497,
 0.16018048837087778,
 0.15986480734960967,
 0.15955372389294997,
 0.1592471413856969,
 0.1589449657733043,
 0.1586471054874422,
 0.1583534713737859,
 0.1580639766219659,
 0.15777853669761477,
 0.15749706927644566,
 0.15721949418030304,
 0.1569457333151249,
 0.15667571061075897,
 0.15640935196257785,
 0.15614658517483707,
 0.15588733990572493,
 0.1556315476140529,
 0.15537914150753623,
 0.15513005649261752,
 0.15488422912578662,
 0.15464159756635176,
 0.15440210153061742,
 0.15416568224742783,
 0.15393228241503384,
 0.15370184615924348,
 0.15347431899281805,
 0.15324964777607533,
 0.15302778067866443,
 0.15280866714247615,
 0.1525922578456555,
 0.15237850466768188,
 0.15216736065548658,
 0.15195877999057447,
 0.1517527179571211,
 0.1515491309110149,
 0.15134797624981636,
 0.15114921238360698,
 0.1509527987067004,
 0.15075869557019017,
 0.15056686425530957,
 0.15037726694757766,
 0.15018986671170936,
 0.15000462746726564,
 0.1498215139650219,
 0.1496404917640332,
 0.14946152720937447,
 0.14928458741053677,
 0.14910964022045875,
 0.14893665421517455,
 0.14876559867406025,
 0.14859644356065968,
 0.1484291595040733,
 0.1482637177808928,
 0.14810009029766483,
 0.1479382495738683,
 0.14777816872538996,
 0.1476198214484827,
 0.1474631820041929,
 0.1473082252032421,
 0.14715492639134978,
 0.14700326143498368,
 0.14685320670752527,
 0.14670473907583773,
 0.14655783588722407,
 0.14641247495676463,
 0.1462686345550211,
 0.14612629339609778,
 0.14598543062604827,
 0.14584602581161768,
 0.14570805892930994,
 0.14557151035477142,
 0.14543636085247996,
 0.1453025915657318,
 0.14517018400691611,
 0.1450391200480697,
 0.14490938191170247,
 0.14478095216188633,
 0.14465381369559951,
 0.14452794973431848,
 0.1444033438158502,
 0.14427997978639776,
 0.14415784179285188,
 0.14403691427530232,
 0.14391718195976197,
 0.14379862985109793,
 0.14368124322616269,
 0.14356500762711993,
 0.14344990885495965,
 0.14333593296319558,
 0.1432230662517408,
 0.1431112952609561,
 0.14300060676586493,
 0.14289098777053136,
 0.14278242550259543,
 0.14267490740796127,
 0.14256842114563387,
 0.14246295458269945,
 0.142358495789446,
 0.14225503303461925,
 0.14215255478081018,
 0.1420510496799703,
 0.14195050656905087,
 0.14185091446576267,
 0.1417522625644519,
 0.14165454023209023,
 0.14155773700437435,
 0.141461842581932,
 0.14136684682663245,
 0.1412727397579967,
 0.14117951154970584,
 0.14108715252620416,
 0.14099565315939372,
 0.1409050040654189,
 0.14081519600153714,
 0.14072621986307407,
 0.1406380666804604,
 0.14055072761634763,
 0.1404641939628014,
 0.14037845713856925,
 0.1402935086864208,
 0.14020934027055865,
 0.14012594367409767,
 0.14004331079661067,
 0.1399614336517384,
 0.1398803043648625,
 0.13979991517083906,
 0.13972025841179114,
 0.13964132653495887,
 0.1395631120906052,
 0.13948560772997556,
 0.1394088062033102,
 0.13933270035790743,
 0.1392572831362367,
 0.13918254757409942,
 0.13910848679883667,
 0.13903509402758246,
 0.138962362565561,
 0.13889028580442658,
 0.13881885722064558,
 0.1387480703739184,
 0.138677918905641,
 0.1386083965374045,
 0.13853949706953195,
 0.1384712143796508,
 0.13840354242130104,
 0.13833647522257653,
 0.13827000688480018,
 0.1382041315812306,
 0.1381388435558005,
 0.13807413712188513,
 0.13801000666110028,
 0.1379464466221289,
 0.13788345151957584,
 0.13782101593284923,
 0.1377591345050687,
 0.13769780194199868,
 0.13763701301100714,
 0.13757676254004808,
 0.13751704541666765,
 0.13745785658703336,
 0.1373991910549853,
 0.1373410438811091,
 0.13728341018182993,
 0.1372262851285268,
 0.13716966394666721,
 0.13711354191496053,
 0.13705791436453063,
 0.13700277667810684,
 0.1369481242892326,
 0.13689395268149127,
 0.13684025738774938,
 0.13678703398941564,
 0.13673427811571623,
 0.1366819854429858,
 0.13663015169397344,
 0.13657877263716303,
 0.1365278440861087,
 0.1364773618987835,
 0.13642732197694216,
 0.13637772026549663,
 0.13632855275190497,
 0.13627981546557258,
 0.13623150447726529,
 0.13618361589853506,
 0.1361361458811566,
 0.13608909061657562,
 0.1360424463353681,
 0.13599620930670997,
 0.13595037583785746,
 0.13590494227363753,
 0.13585990499594844,
 0.1358152604232695,
 0.1357710050101804,
 0.13572713524689023,
 0.1356836476587745,
 0.1356405388059217,
 0.13559780528268783,
 0.1355554437172595,
 0.13551345077122479,
 0.13547182313915254,
 0.13543055754817865,
 0.1353896507576004,
 0.13534909955847768,
 0.13530890077324167,
 0.13526905125531047,
 0.1352295478887108,
 0.13519038758770774,
 0.13515156729643937,
 0.13511308398855892,
 0.13507493466688233,
 0.13503711636304225,
 0.1349996261371476,
 0.13496246107744925,
 0.1349256183000107,
 0.13488909494838502,
 0.13485288819329633,
 0.1348169952323271,
 0.13478141328961019,
 0.1347461396155259,
 0.1347111714864043,
 0.1346765062042316,
 0.13464214109636177,
 0.13460807351523252,
 0.13457430083808555,
 0.13454082046669139,
 0.13450762982707826,
 0.13447472636926538,
 0.1344421075670001,
 0.1344097709174989,
 0.13437771394119274,
 0.1343459341814757,
 0.1343144292044577,
 0.13428319659872076,
 0.1342522339750784,
 0.1342215389663395,
 0.1341911092270747,
 0.13416094243338644,
 0.1341310362826823,
 0.13410138849345168,
 0.13407199680504514,
 0.13404285897745766,
 0.1340139727911139,
 0.13398533604665733,
 0.13395694656474152,
 0.13392880218582504,
 0.13390090076996808,
 0.1338732401966332,
 0.1338458183644872,
 0.13381863319120707,
 0.1337916826132873,
 0.1337649645858507,
 0.1337384770824609,
 0.13371221809493786,
 0.13368618563317541,
 0.1336603777249612,
 0.13363479241579912,
 0.1336094277687338,
 0.13358428186417728,
 0.13355935279973807,
 0.13353463869005203,
 0.13351013766661574,
 0.1334858478776214,
 0.1334617674877945,
 0.13343789467823247,
 0.13341422764624633,
 0.13339076460520344,
 0.1333675037843725,
 0.13334444342877066,
 0.13332158179901155,
 0.13329891717115624,
 0.13327644783656503,
 0.13325417210175164,
 0.13323208828823854,
 0.13321019473241433,
 0.13318848978539244,
 0.13316697181287218,
 0.13314563919500016,
 0.1331244903262345,
 0.13310352361520966,
 0.13308273748460353,
 0.13306213037100528,
 0.13304170072478536,
 0.13302144700996643,
 0.13300136770409596,
 0.13298146129812038,
 0.13296172629625996,
 0.13294216121588615,
 0.132922764587399,
 0.1329035349541069,
 0.13288447087210706,
 0.13286557091016762,
 0.13284683364961072,
 0.1328282576841967,
 0.13280984162001042,
 0.1327915840753473,
 0.13277348368060227,
 0.13275553907815818,
 0.13273774892227672,
 0.13272011187898983,
 0.13270262662599233,
 0.1326852918525356,
 0.13266810625932268,
 0.13265106855840397,
 0.1326341774730744,
 0.13261743173777144,
 0.13260083009797413,
 0.13258437131010334,
 0.13256805414142278,
 0.13255187736994117,
 0.13253583978431552,
 0.1325199401837546,
 0.1325041773779249,
 0.13248855018685587,
 0.1324730574408471,
 0.13245769798037618,
 0.13244247065600748,
 0.13242737432830168,
 0.13241240786772626,
 0.13239757015456718,
 0.13238286007884084,
 0.1323682765402074,
 0.13235381844788482,
 0.1323394847205633,
 0.13232527428632143,
 0.13231118608254225,
 0.1322972190558306,
 0.13228337216193134,
 0.13226964436564798,
 0.13225603464076227,
 0.1322425419699548,
 0.13222916534472598,
 0.13221590376531783,
 0.13220275624063701,
 0.13218972178817762,
 0.13217679943394567,
 0.13216398821238384,
 0.13215128716629693,
 0.13213869534677797,
 0.13212621181313536,
 0.1321138356328202,
 0.13210156588135483,
 0.13208940164226146,
 0.132077342006992,
 0.1320653860748583,
 0.1320535329529629,
 0.1320417817561306,
 0.1320301316068407,
 0.13201858163516003,
 0.1320071309786758,
 0.13199577878243007,
 0.13198452419885442,
 0.1319733663877049,
 0.131962304515998,
 0.13195133775794732,
 0.1319404652949001,
 0.1319296863152752,
 0.1319190000145011,
 0.1319084055949546,
 0.1318979022659001,
 0.13188748924342955,
 0.1318771657504026,
 0.1318669310163877,
 0.13185678427760358,
 0.13184672477686088,
 0.13183675176350523,
 0.13182686449335979,
 0.13181706222866899,
 0.1318073442380425,
 0.13179770979639993,
 0.1317881581849157,
 0.1317786886909647,
 0.13176930060806835,
 0.13175999323584098,
 0.1317507658799369,
 0.13174161785199792,
 0.13173254846960125,
 0.13172355705620795,
 0.13171464294111176,
 0.13170580545938826,
 0.13169704395184512,
 0.1316883577649718,
 0.13167974625089038,
 0.13167120876730698,
 0.13166274467746283,
 0.13165435335008663,
 0.1316460341593466,
 0.13163778648480368,
 0.13162960971136442,
 0.13162150322923485,
 0.13161346643387453,
 0.13160549872595084,
 0.13159759951129413,
 0.13158976820085277,
 0.13158200421064908,
 0.13157430696173483,
 0.13156667588014856,
 0.13155911039687143,
 0.1315516099477853,
 0.13154417397362975,
 0.13153680191996023,
 0.13152949323710658,
 0.1315222473801313,
 0.13151506380878897,
 0.13150794198748567,
 0.13150088138523836,
 0.1314938814756356,
 0.13148694173679754,
 0.131480061651337,
 0.13147324070632052,
 0.1314664783932301,
 0.1314597742079246,
 0.1314531276506024,
 0.13144653822576377,
 0.13144000544217335,
 0.13143352881282383,
 0.13142710785489936,
 0.13142074208973897,
 0.131414431042801,
 0.13140817424362766,
 0.13140197122580938,
 0.13139582152695017,
 0.131389724688633,
 0.13138368025638517,
 0.13137768777964465,
 0.13137174681172598,
 0.13136585690978717,
 0.1313600176347961,
 0.13135422855149823,
 0.1313484892283835,
 0.13134279923765432,
 0.1313371581551934,
 0.13133156556053213,
 0.13132602103681912,
 0.13132052417078882,
 0.13131507455273095,
 0.13130967177645936,
 0.13130431543928223,
 0.13129900514197135,
 0.13129374048873282,
 0.13128852108717703,
 0.1312833465482897,
 0.13127821648640217,
 0.13127313051916345,
 0.13126808826751082,
 0.13126308935564196,
 0.1312581334109867,
 0.13125322006417925,
 0.13124834894903042,
 0.13124351970250062,
 0.13123873196467223,
 0.13123398537872316,
 0.1312292795908998,
 0.13122461425049098,
 0.13121998900980153,
 0.13121540352412622,
 0.13121085745172428,
 0.13120635045379372,
 0.13120188219444595,
 0.13119745234068078,
 0.1311930605623617,
 0.1311887065321909,
 0.1311843899256851,
 0.1311801104211511,
 0.13117586769966202,
 0.1311716614450333,
 0.13116749134379915,
 0.13116335708518903,
 0.13115925836110465,
 0.13115519486609695,
 0.13115116629734297,
 0.13114717235462375,
 0.13114321274030158,
 0.13113928715929773,
 0.1311353953190708,
 0.1311315369295945,
 0.13112771170333615,
 0.1311239193552353,
 0.13112015960268236,
 0.13111643216549748,
 0.1311127367659097,
 0.1311090731285363,
 0.13110544098036211,
 0.13110184005071918,
 0.13109827007126654,
 0.13109473077597045,
 0.1310912219010841,
 0.13108774318512836,
 0.13108429436887195,
 0.13108087519531228,
 0.13107748540965625,
 0.13107412475930125,
 0.1310707929938162,
 0.131067489864923,
 0.1310642151264781,
 0.13106096853445376,
 0.1310577498469203,
 0.1310545588240278,
 0.13105139522798825,
 0.1310482588230578,
 0.13104514937551928,
 0.13104206665366452,
 0.13103901042777755,
 0.13103598047011686,
 0.13103297655489882,
 0.13102999845828087,
 0.1310270459583445,
 0.13102411883507903,
 0.13102121687036491,
 0.13101833984795783,
 0.1310154875534722,
 0.13101265977436552,
 0.13100985629992212,
 0.13100707692123786,
 0.1310043214312044,
 0.13100158962449351,
 0.13099888129754225,
 0.1309961962485374,
 0.13099353427740043,
 0.13099089518577295,
 0.13098827877700148,
 0.13098568485612289,
 0.1309831132298502,
 0.13098056370655764,
 0.13097803609626688,
 0.13097553021063246,
 0.13097304586292796,
 0.130970582868032,
 0.13096814104241458,
 0.13096572020412311,
 0.13096332017276915,
 0.13096094076951487,
 0.13095858181705952,
 0.13095624313962653,
 0.13095392456295019,
 0.13095162591426274,
 0.13094934702228142,
 0.13094708771719588,
 0.13094484783065521,
 0.1309426271957558,
 0.13094042564702846,
 0.13093824302042653,
 0.1309360791533131,
 0.1309339338844496,
 0.13093180705398308,
 0.1309296985034348,
 0.13092760807568815,
 0.13092553561497697,
 0.1309234809668741,
 0.13092144397827965,
 0.13091942449740973,
 0.1309174223737851,
 0.1309154374582199,
 0.13091346960281078,
 0.13091151866092543,
 0.13090958448719198,
 0.13090766693748815,
 0.13090576586893035,
 0.13090388113986332,
 0.1309020126098491,
 0.130900160139657,
 0.1308983235912531,
 0.13089650282778978,
 0.13089469771359577,
 0.13089290811416598,
 0.13089113389615128,
 0.13088937492734887,
 0.13088763107669207,
 0.130885902214241,
 0.13088418821117243,
 0.13088248893977056,
 0.13088080427341722,
 0.13087913408658264,
 0.1308774782548159,
 0.1308758366547358,
 0.13087420916402173,
 0.1308725956614043,
 0.13087099602665653,
 0.13086941014058484,
 0.13086783788502007,
 0.13086627914280877,
 0.13086473379780447,
 0.1308632017348589,
 0.13086168283981364,
 0.13086017699949143,
 0.13085868410168772,
 0.13085720403516238,
 0.13085573668963163,
 0.13085428195575916,
 0.13085283972514883,
 0.13085140989033592,
 0.1308499923447795,
 0.13084858698285434,
 0.13084719369984307,
 0.13084581239192838,
 0.13084444295618522,
 0.13084308529057329,
 0.1308417392939292,
 0.13084040486595921,
 0.13083908190723154,
 0.13083777031916902,
 0.13083647000404175,
 0.13083518086495988,
 0.1308339028058663,
 0.1308326357315295,
 0.13083137954753649,
 0.1308301341602858,
 0.13082889947698037,
 0.1308276754056209,
 0.13082646185499866,
 0.13082525873468898,
 0.13082406595504426,
 0.13082288342718762,
 0.1308217110630058,
 0.13082054877514324,
 0.13081939647699484,
 0.13081825408270004,
 0.13081712150713629,
 0.13081599866591254,
 0.13081488547536319,
 0.13081378185254178,
 0.1308126877152146,
 0.13081160298185487,
 0.1308105275716365,
 0.13080946140442812,
 0.13080840440078706,
 0.1308073564819534,
 0.1308063175698443,
 0.13080528758704793,
 0.13080426645681778,
 0.13080325410306715,
 0.13080225045036314,
 0.13080125542392118,
 0.13080026894959965,
 0.1307992909538939,
 0.13079832136393135,
 0.13079736010746548,
 0.13079640711287094,
 0.1307954623091379,
 0.13079452562586683,
 0.13079359699326334,
 0.13079267634213287,
 0.13079176360387565,
 0.1307908587104814,
 0.13078996159452447,
 0.1307890721891588,
 0.1307881904281127,
 0.13078731624568427,
 0.13078644957673607,
 0.13078559035669074,
 0.13078473852152586,
 0.1307838940077693,
 0.1307830567524944,
 0.13078222669331543,
 0.13078140376838285,
 0.13078058791637864,
 0.13077977907651198,
 0.13077897718851425,
 0.13077818219263512,
 0.1307773940296377,
 0.13077661264079418,
 0.13077583796788161,
 0.13077506995317736,
 0.130774308539455,
 0.13077355366997984,
 0.13077280528850505,
 0.13077206333926703,
 0.13077132776698147,
 0.1307705985168394,
 0.13076987553450267,
 0.1307691587661004,
 0.13076844815822458,
 0.13076774365792623,
 0.13076704521271165,
 0.13076635277053805,
 0.1307656662798101,
 0.13076498568937595,
 0.13076431094852323,
 0.13076364200697563,
 0.1307629788148889,
 0.13076232132284715,
 0.1307616694818592,
 0.13076102324335503,
 0.13076038255918201,
 0.1307597473816014,
 0.13075911766328474,
 0.1307584933573104,
 0.13075787441716,
 0.13075726079671504,
 0.13075665245025325,
 0.13075604933244553,
 0.13075545139835226,
 0.13075485860342015,
 0.13075427090347866,
 0.13075368825473718,
 0.13075311061378125,
 0.13075253793756964,
 0.13075197018343104,
 0.13075140730906076,
 0.13075084927251807,
 0.1307502960322223,
 0.13074974754695046,
 0.13074920377583368,
 0.13074866467835453,
 0.13074813021434364,
 0.130747600343977,
 0.13074707502777286,
 0.13074655422658876,
 0.1307460379016188,
 0.1307455260143904,
 0.13074501852676193,
 0.13074451540091928,
 0.1307440165993736,
 0.13074352208495807,
 0.13074303182082542,
 0.130742545770445,
 0.1307420638976003,
 0.1307415861663858,
 0.13074111254120485,
 0.13074064298676666,
 0.1307401774680837,
 0.13073971595046924,
 0.1307392583995346,
 0.13073880478118677,
 0.13073835506162565,
 0.13073790920734168,
 0.13073746718511345,
 0.13073702896200484,
 0.13073659450536299,
 0.13073616378281558,
 0.13073573676226868,
 0.130735313411904,
 0.13073489370017688,
 0.13073447759581372,
 0.1307340650678097,
 0.13073365608542659,
 0.13073325061819008,
 0.1307328486358882,
 0.13073245010856818,
 0.13073205500653512,
 0.13073166330034908,
 0.1307312749608232,
 0.13073088995902146,
 0.1307305082662567,
 0.13073012985408813,
 0.1307297546943194,
 0.13072938275899676,
 0.1307290140204064,
 0.13072864845107288,
 0.13072828602375694,
 0.13072792671145325,
 0.1307275704873888,
 0.1307272173250206,
 0.1307268671980337,
 0.13072652008033953,
 0.13072617594607358,
 0.13072583476959368,
 0.13072549652547816,
 0.1307251611885236,
 0.1307248287337435,
 0.13072449913636588,
 0.13072417237183181,
 0.13072384841579335,
 0.13072352724411188,
 0.1307232088328562,
 0.130722893158301,
 0.13072258019692454,
 0.13072226992540745,
 0.1307219623206308,
 0.13072165735967436,
 0.13072135501981477,
 0.13072105527852415,
 0.1307207581134681,
 0.13072046350250424,
 0.13072017142368056,
 0.13071988185523364,
 0.13071959477558712,
 0.1307193101633501,
 0.13071902799731558,
 0.13071874825645874,
 0.1307184709199356,
 0.13071819596708112,
 0.1307179233774081,
 0.13071765313060527,
 0.130717385206536,
 0.1307171195852367,
 0.13071685624691548,
 0.13071659517195028,
 0.13071633634088803,
 0.13071607973444263,
 0.13071582533349385,
 0.1307155731190857,
 0.13071532307242514,
 0.1307150751748808,
 0.13071482940798118,
 0.13071458575341377,
 0.1307143441930234,
 0.13071410470881087,
 0.13071386728293166,
 0.13071363189769475,
 0.13071339853556113,
 0.13071316717914244,
 0.13071293781119986,
 0.13071271041464275,
 0.1307124849725273,
 0.1307122614680553,
 0.13071203988457306,
 0.13071182020556996,
 0.13071160241467716,
 0.13071138649566671,
 0.13071117243244995,
 0.1307109602090767,
 0.13071074980973368,
 0.13071054121874362,
 0.13071033442056404,
 0.13071012939978588,
 0.1307099261411327,
 0.13070972462945923,
 0.13070952484975046,
 0.1307093267871204,
 0.13070913042681098,
 0.130708935754191,
 0.13070874275475497,
 0.1307085514141222,
 0.13070836171803538,
 0.13070817365235993,
 0.1307079872030827,
 0.13070780235631096,
 0.1307076190982714,
 0.1307074374153091,
 0.13070725729388652,
 0.13070707872058232,
 0.1307069016820908,
 0.13070672616522036,
 0.13070655215689286,
 0.13070637964414267,
 0.13070620861411547,
 0.1307060390540674,
 0.13070587095136435,
 0.1307057042934805,
 0.1307055390679979,
 0.13070537526260517,
 0.13070521286509698,
 0.13070505186337264,
 0.13070489224543566,
 0.1307047339993925,
 0.13070457711345204,
 0.13070442157592427,
 0.13070426737521984,
 0.1307041144998489,
 0.1307039629384204,
 0.13070381267964126,
 0.13070366371231526,
 0.13070351602534264,
 0.13070336960771892]

实验要求3 绘制误差曲线

# Plot the multivariable cost recorded after each iteration
plt.plot(range(iter_num2),cost_lst2,"r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()

3

1.4 最小二乘法求参数

最小二乘法需要求解最优参数 $w^{*}$。

已知:目标函数

$$J(w)=\frac{1}{2m}\sum_{i=1}^{m}\left(h(x^{(i)})-y^{(i)}\right)^{2}$$

其中:$h(x)=w^{T}x=w_{0}x_{0}+w_{1}x_{1}+w_{2}x_{2}+\dots+w_{n}x_{n}$(约定 $x_{0}=1$,对应数据中插入的 ones 列)

将向量表达形式转为矩阵表达形式,则有 $J(w)=\frac{1}{2m}(Xw-y)^{T}(Xw-y)$。其中 $X$ 为 $m$ 行 $n+1$ 列的矩阵($m$ 为样本个数,$n$ 为特征个数),$w$ 为 $n+1$ 行 1 列的矩阵(包含了 $w_0$),$y$ 为 $m$ 行 1 列的矩阵。令 $J(w)$ 对 $w$ 的梯度 $\frac{1}{m}X^{T}(Xw-y)$ 为零,得到正规方程 $X^{T}Xw=X^{T}y$,从而求得最优参数 $w^{*}=(X^{T}X)^{-1}X^{T}y$(要求 $X^{T}X$ 可逆)。

梯度下降与最小二乘法的比较:

梯度下降:需要选择学习率 $\alpha$,需要多次迭代;当特征数量 $n$ 较大时也能较好适用;适用于各种类型的模型。

最小二乘法:不需要选择学习率 $\alpha$,一次计算得出;需要计算 $(X^{T}X)^{-1}$,如果特征数量 $n$ 较大则运算代价大,因为矩阵求逆的时间复杂度为 $O(n^3)$,通常来说当 $n$ 小于 10000 时还是可以接受的;只适用于线性模型,不适合逻辑回归等其他模型。

def lsm(X,y):
    """Closed-form least-squares solution (normal equation).

    Returns w* = (X^T X)^{-1} X^T y, the minimizer of sum((X @ w - y)**2).

    The linear system (X^T X) w = X^T y is solved directly instead of forming
    the explicit inverse. For a non-singular X^T X the result is the same,
    but it is cheaper and numerically more stable.

    Args:
        X: design matrix of shape (m, n).
        y: targets of shape (m, 1) (or (m,), giving an (n,) result).

    Returns:
        The optimal weights, with the same trailing shape as y.

    Raises:
        numpy.linalg.LinAlgError: if X^T X is singular (e.g. collinear columns).
    """
    w=np.linalg.solve(X.T@X,X.T@y)
    return w
def lsm_v(X,y):
    """Normal-equation least squares written with ``np.dot``.

    Equivalent to ``lsm``: it returns w* = (X^T X)^{-1} X^T y, computed by
    solving (X^T X) w = X^T y rather than by inverting X^T X explicitly
    (same result, better numerical stability).

    Raises:
        numpy.linalg.LinAlgError: if X^T X is singular.
    """
    gram=np.dot(X.T,X)      # X^T X, shape (n, n)
    rhs=np.dot(X.T,y)       # X^T y, shape (n, 1)
    w=np.linalg.solve(gram,rhs)
    return w
lsm(X,y)
array([[-3.89578088],
       [ 1.19303364]])
lsm_v(X,y)
array([[-3.89578088],
       [ 1.19303364]])

1.5 来点正则化?

1.5.1 普通的线性回归

from sklearn import linear_model
# Plain (unregularized) least-squares fit on the single-variable data X, y
reg=linear_model.LinearRegression()
reg.fit(X,y)
LinearRegression()
# Back to single-variable linear regression: draw the fitted line.
# X holds [ones, 人口]. Plot against the population column only; plotting all
# of X would also draw a spurious line for the constant ones column at x = 1.
x=X[:,1]
y_1=reg.predict(X)
plt.plot(x,y_1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
# No plt.xlim clipping here. The old xlim(4.7, 10) only existed to hide the
# ones-column line, and it also cut off the samples with 人口 > 10.
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
plt.legend()  # show the labels set above
plt.show()

4

reg.coef_,reg.intercept_,reg.score(X,y)
(array([[0.        , 1.19303364]]), array([-3.89578088]), 0.7020315537841397)

1.5.2 岭回归

$$J(w)=\frac{1}{2}\sum_{i=1}^{m}\left(h_{w}(x^{(i)})-y^{(i)}\right)^{2}+\lambda\sum_{j=1}^{n}w_{j}^{2}$$
此时称作 Ridge Regression(岭回归)。注意正则项从 $j=1$ 开始求和,偏置 $w_0$ 不参与惩罚。

from sklearn import linear_model
# Ridge (L2-regularized) regression with scikit-learn's default settings
# (regularization strength alpha=1.0).
reg_rigde=linear_model.Ridge()
reg_rigde.fit(X,y)
Ridge()
# Back to single-variable linear regression: Ridge fit.
# Plot against the population column X[:, 1]; using all of X would also draw
# a spurious line for the constant ones column.
x=X[:,1]
y_1=reg_rigde.predict(X)
plt.plot(x,y_1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
# No xlim clipping: it would hide the samples with 人口 > 10
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
plt.legend()  # show the labels set above
plt.show()

5

reg_rigde.coef_,reg_rigde.intercept_,reg_rigde.score(X,y)
(array([[0.       , 1.1922044]]), array([-3.88901439]), 0.7020312146131912)

1.5.3 Lasso回归

$$J(w)=\frac{1}{2}\sum_{i=1}^{m}\left(h_{w}(x^{(i)})-y^{(i)}\right)^{2}+\lambda\sum_{j=1}^{n}\left|w_{j}\right|$$
此时称作 Lasso Regression(Lasso 回归)。

from sklearn import linear_model
# Lasso (L1-regularized) regression with scikit-learn's default settings
# (regularization strength alpha=1.0).
reg_lasso=linear_model.Lasso()
reg_lasso.fit(X,y)
Lasso()
# Back to single-variable linear regression: Lasso fit.
# Plot against the population column X[:, 1]; using all of X would also draw
# a spurious line for the constant ones column.
x=X[:,1]
y_1=reg_lasso.predict(X)
plt.plot(x,y_1,"r-+",label="预测线")
plt.scatter(data["人口"],data["收益"], label='训练数据')
# No xlim clipping: it would hide the samples with 人口 > 10
plt.xlabel("人口",fontsize=10)
plt.ylabel("收益",fontsize=10)
plt.title("人口与收益之间的关系")
plt.legend()  # show the labels set above
plt.show()

6

reg_lasso.coef_,reg_lasso.intercept_,reg_lasso.score(X,y)
(array([0.        , 1.12556458]), array([-3.34524677]), 0.6997863246152711)

实验要求4 手写代码实现单变量的L2正则化

$$J(w)=\frac{1}{2}\sum_{i=1}^{m}\left(h_{w}(x^{(i)})-y^{(i)}\right)^{2}+\lambda\sum_{j=1}^{n}w_{j}^{2}$$
此时称作 Ridge Regression(岭回归)。正则项从 $j=1$ 开始求和,偏置 $w_0$ 不参与惩罚。

# Batch gradient descent with an L2 (ridge) penalty, over all samples.
# Hyper-parameters: iteration count iter_num, learning rate alpha, and
# regularization strength lambd.
def gradient_descent_l2(X,y,w,iter_num,alpha,lambd,penalize_bias=False):
    """Fit ridge-regularized linear-regression weights by gradient descent.

    Minimizes (1/m) * [ 1/2 * sum((X @ w - y)**2) + lambd * sum_{j>=1} w_j**2 ],
    i.e. the ridge objective J(w) scaled by 1/m, using
    w <- w - (alpha / m) * (X.T @ (X @ w - y) + 2 * lambd * w_penalized).

    As in the stated ridge objective, the penalty sum starts at j = 1. The
    bias w_0 (the weight of the leading ones column of X) is therefore NOT
    shrunk unless ``penalize_bias=True``.

    Args:
        X: design matrix of shape (m, n); column 0 is expected to be the
            constant ones column for the bias.
        y: targets, shape (m,) or (m, 1).
        w: initial weights, n values (shape (n, 1) or (n,)); not modified.
        iter_num: number of iterations.
        alpha: learning rate.
        lambd: L2 regularization strength (lambda).
        penalize_bias: if True, also penalize w_0 (the previous behavior).

    Returns:
        (w, cost_lst): the final weights as an (n, 1) array, and the
        unregularized halved-MSE cost sum((X @ w - y)**2) / (2m) recorded
        after every iteration.

    Raises:
        ValueError: if X contains no samples.
    """
    X = np.asarray(X, dtype=float)
    y_col = np.asarray(y, dtype=float).reshape(-1, 1)
    w = np.asarray(w, dtype=float).reshape(-1, 1)
    m_samples = X.shape[0]
    if m_samples == 0:
        raise ValueError("X must contain at least one sample")
    # 1 for every weight that is regularized; 0 for the bias unless requested
    penalty_mask = np.ones_like(w)
    if not penalize_bias and penalty_mask.size:
        penalty_mask[0, 0] = 0.0
    cost_lst = []
    for _ in range(iter_num):
        error = X @ w - y_col
        grad = X.T @ error + 2 * lambd * penalty_mask * w
        # Simultaneous update of all weights; a new array each step, so the
        # caller's w is never mutated.
        w = w - (alpha / m_samples) * grad
        residual = X @ w - y_col
        cost_lst.append(np.sum(np.square(residual)) / (2 * m_samples))
    return w, cost_lst
# Hyper-parameters for the hand-written L2-regularized fit.
# NOTE: this reuses and overwrites the globals iter_num, alpha, w and cost_lst
# from the single-variable section.
iter_num=200
alpha=0.001
lambd=2
w=np.zeros((col_num-1,1))
w,cost_lst=gradient_descent_l2(X,y,w,iter_num,alpha,lambd)
# Plot the cost recorded after each iteration
plt.plot(range(iter_num),cost_lst,"r-+")
plt.xlabel("迭代次数")
plt.ylabel("误差")
plt.show()

7

猜你喜欢

转载自blog.csdn.net/m0_68111267/article/details/131892463