机器学习(一元线性回归模型)

模型:一元线性回归模型

回归分析:建立方程模拟两个或者多个变量之间是如何相互关联,被预测的变量称为因变量(结果),用来进行预测的变量称为

自变量(输入参数),当输入参数只有一个(输出1个)时,称为一元回归,反之当输入有多个(输出1个),称为多元回归;

一元线性回归模型如下所示:(我们只需确定此方程的两个参数即可)

第一个参数为截距,第二个参数为斜率(我们只需根据大量的数据集通过训练求解即可^_^)

为了求解上述参数(或许求解的参数更多)我们在这里引入代价函数(cost function),在这里以一元线性回归模型为例

上述式子不难理解,为真实值与预测值之间的差值的平方(当然也可以取绝对值,但为便于后续数值操作取其平方),最后

取所有训练集个数的平均值。而我们就是在取某个(些)参数时,使得上述函数取得最小值(误差最小!)结合下图直观理解。

 上述由一些参数构成的函数称为代价函数,我们的目标就是求解对应的参数使得代价函数达到最小值,最后确定模型;

梯度下降法(确定所要求解的参数,在神经网络模型中也会有所应用): 其具体步骤如下图所示

使用此方法最重要的就是确定代价函数,且此函数可以收敛到最小值(凸函数),对于初始化操作,一般情况下赋值为0即可

所谓的梯度优化,就是不断的更改参数值,使之最后到达一个全局(局部)最小值,参数更新过程如下;

一般情况下的参数求解,对每个参数求偏导数;

这一个参数更新要求是同步更新,即最后在对参数进行更新,这里的α为学习率,通常取值为0.01,0.001,0.03,0.003等

学习率不易过高也不能过低,当过高是会导致永远到达不了收敛点(发散),当过低时会导致收敛过慢,影响收敛速度;

在这里的一元线性回归模型我们确定的参数只有两个,其参数求解方式如下图所示:

实战 (fight)

在这里我们以一个生产成本对月产量的影响为例,确定两者之间的线性关系(当然未必一定是线性)

 下图红色部分是通过梯度下降求解出的拟合直线,蓝色为通过python自带的库sklearn拟合出的直线(两者大体相同)

import numpy as np  #导入numpy
import matplotlib.pyplot as plt #导入图像绘制库
from sklearn.linear_model import LinearRegression  #线性回归库

 上述为我们常用的三个库文件的导入方式,numpy(矩阵运算),pyplot(图像绘制),sklearn

#训练集,以及相关参数
dataX=[]# X数据集
dataY=[]# Y数据集
with open('DataSet') as f:
    for line in f:
        line=line.strip('\n').split(' ')
        dataX.append(float(line[0]))
        dataY.append(float(line[1]))

k=0 # 斜率
b=0 # 截距  初始化参数
learnrate= 0.03 # 学习率
step = 50 # 学习次数
#梯度下降法求解参数
for i in range(step):
    temp1=0
    temp0=0
    for j in range(len(dataX)):
        temp0=temp0+(k*dataX[j]+b-dataY[j])/len(dataX)
        temp1=temp1+(k*dataX[j]+b-dataY[j])*dataX[j]/len(dataX)
    k=k-learnrate*temp1
    b=b-learnrate*temp0 #求解参数
#通过sklearn库训练数据
model=LinearRegression()
model.fit(np.array(dataX).reshape(-1,1),np.array(dataY).reshape(-1,1))
plt.plot(dataX,model.predict(np.array(dataX).reshape(-1,1)),'b')
#绘制散点以及拟合直线图像
plt.xlabel('X')
plt.ylabel('Y')# 设置横纵坐标
plt.scatter(dataX,dataY)# 绘制散列表
plt.plot(dataX,k*np.array(dataX)+b,'r')# 绘制拟合直线
plt.show()

这里的scatter传入的参数为X和Y坐标列表,要求X和Y列表内元素个数相同,plot用于绘制曲线第一个参数为X坐标,第二

个参数为函数方程(Y坐标),第三个参数为曲线的颜色。

拟合出的一元线性回归方程为:y=1.274974x+0.879003;

发布了79 篇原创文章 · 获赞 81 · 访问量 5712

猜你喜欢

转载自blog.csdn.net/weixin_44638960/article/details/104027392