简单线性回归算法

一、目标

              寻找一条直线,最大程度的“拟合”样本特征和样本输出标记之间的关系。在回归问题中我们预测的是一个具体的数值,这个具体的数值是在一个连续的空间里的,如果想看两个特征的回归问题就需要在三维空间里进行观察。样本特征有多个的回归称为多元线性回归

损失函数

对a求偏导数:

最后得到的结果:

求a、b的Python代码:

封装SampleLinearRegression算法的代码实现

"""coding:utf-8"""
import numpy as np
class SimpleLinearRegression(object):
    def __init__(self):
        """初始化Simple Linear Regression 模型"""
        self.a_ = None
        self.b_ = None
    def fit(self,x_train,y_train):
        """根据训练数据集x_train,y_train训练Simple Linear Regression模型"""
        assert x_train.ndim == 1, \
            "Simple Linear Regressor can only solve single feature training data."
        assert len(x_train) == len(y_train), \
            "the size of x_train must be equal to the size of y_train"
        x_mean = np.mean(x_train)
        y_mean = np.mean(y_train)
        num = 0.0
        d = 0.0
        for x,y in zip(x_train,y_train):
            num += (x-x_mean)*(y-y_mean)
            d += (x-x_mean)**2

        self.a_ = num/d
        self.b_ = y_mean-self.a_*x_mean
        return self

    def predict(self,x_predict):
        """给定待预测数据集x_predict,返回表示x_predict的结果向量"""
        assert x_predict.ndim == 1, \
            "Simple Linear Regressor can only solve single feature training data."
        assert self.a_ is not None and self.b_ is not None, \
            "must fit before predict!"
        return np.array([self._predict(x) for x in x_predict])

    def _predict(self,x):
        """给定单个待预测数据x,返回x的预测结果值"""
        return self.a_ * x +self.b_

    def __repr__(self):
        return "SimpleLinearRegression1()"

检验封装算法的测试代码

"""coding:utf-8"""
import numpy as np
import matplotlib.pyplot as plt
X = np.array([1.,2.,3.,4.,5.])
y = np.array([1.,3.,2.,3.,5.])
plt.scatter(X,y)
plt.axis([0,6,0,6])
plt.show()
from play_ML.SimpleLinearRrgression import SimpleLinearRegression
slr = SimpleLinearRegression()
slr.fit(X,y)
y_hat = slr.predict(X)
plt.scatter(X,y)
plt.plot(X,y_hat,color="r")
plt.axis([0,6,0,6])
plt.show()

测试结果

                           

猜你喜欢

转载自blog.csdn.net/ITpfzl/article/details/82946486
今日推荐