Linear models in machine learning: least squares and gradient descent

"""
@author: JacksonKim
@filename: linear_regression.py
@start: 2021/02/01
@end:   2021/02/02
"""

import numpy as np
import matplotlib.pyplot as plt

'''
1. Linear models are simple in form and easy to build, and many more powerful nonlinear
models can be derived from them by introducing hierarchical structure or high-dimensional
mappings. Linear models are also easy to interpret.
2. Linear regression is one of the simpler models for regression tasks; its form is f(x) = WX + b.
3. This project implements a linear regression model from scratch, i.e. without calling existing libraries.
'''
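# For reference, a sketch of the math the two solvers below implement (here X already
# includes a trailing column of ones when fit_intercept is True, so the bias is the last weight):
#   closed-form least squares:  w* = (X^T X)^(-1) X^T y        (computed via np.linalg.pinv)
#   MSE gradient for SGD:       dL/dw = -2/m * X^T (y - X w),  update  w <- w - eta * dL/dw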


class linear(object):
    def __init__(self, fit_intercept=True, solver='sgd', if_standard=True, epochs=10, eta=1e-2, batch_size=1):
        """
        :param fit_intercept: 是否训练bias
        :param solver:
        :param if_standard:
        :param epochs:
        :param eta:
        :param batch_size:
        """
        self.w = None
        self.fit_intercept = fit_intercept
        self.solver = solver
        self.if_standard = if_standard
        if if_standard:
            self.feature_mean = None
            self.feature_std = None
        self.epochs = epochs
        self.eta = eta
        self.batch_size = batch_size

    def init_params(self, n_features):
        """
        Initialize the weights with random values.
        :param n_features: number of features (including the bias column when fit_intercept is True)
        :return:
        """
        self.w = np.random.random(size=(n_features, 1))

    def _fit_closed_form_solution(self, x, y):
        """
        Solve directly with the closed-form least-squares solution.
        :param x:
        :param y:
        :return:
        """
        # pinv(x).dot(y) is the minimum-norm least-squares solution,
        # equal to (x^T x)^(-1) x^T y when x^T x is invertible
        self.w = np.linalg.pinv(x).dot(y)

    def _fit_sgd(self, x, y):
        """
        Solve with mini-batch stochastic gradient descent.
        :param x:
        :param y:
        :return:
        """
        x_y = np.c_[x, y]
        # update w (the bias is the last weight) one mini-batch of size batch_size at a time
        for _ in range(self.epochs):
            np.random.shuffle(x_y)
            for index in range(x_y.shape[0] // self.batch_size):
                batch_x_y = x_y[self.batch_size * index:self.batch_size * (index + 1)]
                batch_x = batch_x_y[:, :-1]
                batch_y = batch_x_y[:, -1:]

                # gradient of the mean squared error: dw = -2/m * X^T (y - Xw)
                dw = -2 * batch_x.T.dot(batch_y - batch_x.dot(self.w)) / self.batch_size
                self.w = self.w - self.eta * dw

    def fit(self, x, y):
        """
        Train the model.
        :param x:
        :param y:
        :return:
        """
        # standardize features if requested
        if self.if_standard:
            self.feature_mean = np.mean(x, axis=0)
            self.feature_std = np.std(x, axis=0) + 1e-8
            x = (x - self.feature_mean) / self.feature_std
        # append a column of ones so the bias is learned as the last weight
        if self.fit_intercept:
            x = np.c_[x, np.ones_like(y)]
        # initialize parameters
        self.init_params(x.shape[1])
        # fit with the chosen solver
        if self.solver == 'closed_form':
            self._fit_closed_form_solution(x, y)
        elif self.solver == 'sgd':
            self._fit_sgd(x, y)

    def get_params(self):
        """
        Return the coefficients on the original (unstandardized) feature scale.
        :return:
        """
        if self.fit_intercept:
            w = self.w[:-1]
            b = self.w[-1]
        else:
            w = self.w
            b = 0
        if self.if_standard:
            # undo standardization: f = w_s*(x - mean)/std + b_s  =>  w = w_s/std, b = b_s - w^T*mean
            w = w / self.feature_std.reshape(-1, 1)
            b = b - w.T.dot(self.feature_mean.reshape(-1, 1))
        return w.reshape(-1), b

    def predict(self, x):
        """
        Predict targets for the given data.
        :param x: ndarray of shape (m, n)
        :return:
        """
        if self.if_standard:
            x = (x - self.feature_mean) / self.feature_std
        if self.fit_intercept:
            x = np.c_[x, np.ones(shape=x.shape[0])]
        return x.dot(self.w)

    def plot_fit_boundary(self, x, y):
        """
        Plot the fitted line against the data (first feature only).
        :param x:
        :param y:
        :return:
        """
        plt.scatter(x[:, 0], y)
        plt.plot(x[:, 0], self.predict(x), 'r')


# Test
# generate a synthetic dataset: y = 3x + 2, then add Gaussian noise to the feature
X = np.linspace(0, 100, 100)
X = np.c_[X, np.ones(100)]
W = np.array([3, 2])
Y = X.dot(W)
X = X.astype('float')
Y = Y.astype('float')
X[:, 0] += np.random.normal(size=X[:, 0].shape) * 3
Y = Y.reshape(100, 1)
lr = linear(solver='sgd')
lr.fit(X[:, :-1], Y)
predict = lr.predict(X[:, :-1])
# inspect the learned coefficients (w, b)
print('w', lr.get_params())
# standard deviation of the residuals
print(np.std(Y - predict))
# visualize the fit
lr.plot_fit_boundary(X[:, :-1], Y)
plt.show()
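As a quick sanity check (not part of the original script), the same data could also be fit with the closed-form solver and compared with the SGD result; a minimal sketch, assuming it is appended to the script above:

# optional comparison: closed-form least squares on the same data
lr_cf = linear(solver='closed_form')
lr_cf.fit(X[:, :-1], Y)
print('w (closed form)', lr_cf.get_params())
print('residual std (closed form)', np.std(Y - lr_cf.predict(X[:, :-1])))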



Origin blog.csdn.net/charenCsdn/article/details/113543712