Linear Modeling: Least Squares (A First Course in Machine Learning)

Fitting a Simple Linear Model

This model suits one-dimensional data, that is, data with a single input variable x.

Define the model:

y = w_0+w_1x \tag1

Define the average loss function (the mean squared error over the N data points):

L = \frac{1}{N} \sum_{n=1}^N (y_n-(w_0+w_1x_n))^2
 \tag2
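
The average loss can be evaluated directly for any candidate pair of coefficients, which makes it easy to see that a better line gives a smaller value. The following is a minimal sketch, not part of the original post; the function name `averageLoss` and its signature are assumptions chosen for illustration, and each residual uses the data point's own input x_n.

```typescript
// Average squared loss for the line y = w0 + w1 * x over N data points.
// The name `averageLoss` and its signature are illustrative assumptions.
function averageLoss(xs: number[], ys: number[], w0: number, w1: number): number {
  if (xs.length === 0 || xs.length !== ys.length) {
    throw new Error("xs and ys must be non-empty and of equal length");
  }
  let sum = 0;
  for (let n = 0; n < xs.length; n++) {
    // Residual between the observed y_n and the model's prediction at x_n.
    const residual = ys[n] - (w0 + w1 * xs[n]);
    sum += residual * residual;
  }
  return sum / xs.length;
}

// The points (1,3), (2,5), (3,7) lie exactly on y = 1 + 2x, so the loss is 0.
const lossOnLine = averageLoss([1, 2, 3], [3, 5, 7], 1, 2);
// Lowering the intercept to 0 makes every residual equal to 1, so the loss is 1.
const lossShifted = averageLoss([1, 2, 3], [3, 5, 7], 0, 2);
```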

The problem then becomes finding the values of w_0 and w_1 that minimize L, that is:

\begin{gather}
\frac{\partial L}{\partial w_0} = 0 \\
\frac{\partial L}{\partial w_1} = 0
\tag3
\end{gather}

Take the partial derivatives and solve the two resulting equations simultaneously.

Let:

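\overline{x} = \frac{1}{N} \sum_{n=1}^N x_n, \quad
\overline{y} = \frac{1}{N} \sum_{n=1}^N y_n, \quad
\overline{xy} = \frac{1}{N} \sum_{n=1}^N x_n y_n, \quad
\overline{x^2} = \frac{1}{N} \sum_{n=1}^N x_n^2
\tag4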

Solving gives:

\begin{gather} 
    w_0 = \overline{y}-w_1\overline{x} \\ \\
    w_1 = \frac{\overline{xy}-\overline{x}\,\overline{y}}{\overline{x^2}-\overline{x}^2}
    \tag5
\end{gather}
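
The intermediate steps behind (5) can be written out as follows. This is a sketch of the standard derivation rather than text from the original post; here \overline{xy} and \overline{x^2} denote the averages of x_n y_n and x_n^2 over the N data points.

```latex
\begin{gather}
\frac{\partial L}{\partial w_0}
  = -\frac{2}{N}\sum_{n=1}^N \left(y_n - w_0 - w_1 x_n\right) = 0
  \;\Rightarrow\; w_0 = \overline{y} - w_1\overline{x} \\
\frac{\partial L}{\partial w_1}
  = -\frac{2}{N}\sum_{n=1}^N x_n\left(y_n - w_0 - w_1 x_n\right) = 0
  \;\Rightarrow\; \overline{xy} - w_0\overline{x} - w_1\overline{x^2} = 0 \\
\overline{xy} - \overline{x}\,\overline{y} + w_1\overline{x}^2 - w_1\overline{x^2} = 0
  \;\Rightarrow\; w_1 = \frac{\overline{xy}-\overline{x}\,\overline{y}}{\overline{x^2}-\overline{x}^2}
\end{gather}
```

As a quick check with the three points (1, 3), (2, 5), (3, 7): the averages are \overline{x} = 2, \overline{y} = 5, \overline{xy} = 34/3 and \overline{x^2} = 14/3, so w_1 = (34/3 - 10)/(14/3 - 4) = 2 and w_0 = 5 - 2 \cdot 2 = 1, which recovers the line y = 1 + 2x that the points lie on.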

More General Linear Models

Matrix notation lets us derive a more general form of the model.

Let:

\mathbf{x}_n=\left[
    \begin{matrix}
       1  \\
       x_n
    \end{matrix} 
    \right]
    ,
    \mathbf{w} = \left[
        \begin{matrix}
        w_0\\
        w_1
        \end{matrix}
    \right]
    \tag{6}

This gives:

\mathbf{X}=\left[
    \begin{matrix}
        \mathbf{x}_1^T \\
        \mathbf{x}_2^T \\
        \vdots \\
        \mathbf{x}_N^T
    \end{matrix}
\right]
= \left[
    \begin{matrix}
        1 & x_1 \\
        1 & x_2 \\
        \vdots & \vdots \\
        1 & x_N
    \end{matrix}
\right]
,
\mathbf{y} = \left[
    \begin{matrix}
        y_1 \\
        y_2 \\
        \vdots \\
        y_N
    \end{matrix}
\right]

The average loss function can then be written as:

L = \frac{1}{N}(\mathbf{y}-\mathbf{X}\mathbf{w})^T(\mathbf{y}-\mathbf{X}\mathbf{w}) \\
    =\frac{1}{N}\mathbf{w}^T\mathbf{X}^T\mathbf{X}\mathbf{w}-\frac{2}{N}\mathbf{w}^T\mathbf{X}^T\mathbf{y}+\frac{1}{N}\mathbf{y}^T\mathbf{y}

Take the partial derivative with respect to w and set it to zero:

\frac{\partial L}{\partial \mathbf{w}}
= \frac{2}{N}\mathbf{X}^T\mathbf{X}\mathbf{w} - \frac{2}{N}\mathbf{X}^T\mathbf{y} = 0

This gives:

\bf X^T\bf X\bf w = \bf X^T\bf y

That is, provided X^T X is invertible:

\mathbf{w} = (\mathbf{X}^T\mathbf{X})^{-1} \mathbf{X}^T \mathbf{y}
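
The matrix result can also be turned into code. The sketch below is not from the original post: it builds X^T X and X^T y and then solves the linear system X^T X w = X^T y by Gaussian elimination with partial pivoting, rather than forming the inverse explicitly. The function name `solveNormalEquations`, the sample data, and the singularity threshold are all assumptions made for illustration.

```typescript
// Solve (X^T X) w = X^T y by Gaussian elimination with partial pivoting.
// `solveNormalEquations` is a hypothetical helper, not code from the post.
function solveNormalEquations(X: number[][], y: number[]): number[] {
  const N = X.length;
  if (N === 0 || N !== y.length) {
    throw new Error("X and y must be non-empty with the same number of rows");
  }
  const K = X[0].length;

  // Build the augmented matrix [X^T X | X^T y], of size K x (K + 1).
  const A: number[][] = [];
  for (let i = 0; i < K; i++) {
    const row = new Array<number>(K + 1).fill(0);
    for (let n = 0; n < N; n++) {
      for (let j = 0; j < K; j++) row[j] += X[n][i] * X[n][j];
      row[K] += X[n][i] * y[n];
    }
    A.push(row);
  }

  // Forward elimination; pick the largest pivot in each column for stability.
  for (let col = 0; col < K; col++) {
    let pivot = col;
    for (let r = col + 1; r < K; r++) {
      if (Math.abs(A[r][col]) > Math.abs(A[pivot][col])) pivot = r;
    }
    // The threshold 1e-12 is an arbitrary assumption for this sketch.
    if (Math.abs(A[pivot][col]) < 1e-12) {
      throw new Error("X^T X is singular or nearly singular");
    }
    const tmp = A[col];
    A[col] = A[pivot];
    A[pivot] = tmp;
    for (let r = col + 1; r < K; r++) {
      const factor = A[r][col] / A[col][col];
      for (let c = col; c <= K; c++) A[r][c] -= factor * A[col][c];
    }
  }

  // Back substitution on the upper-triangular system.
  const w = new Array<number>(K).fill(0);
  for (let i = K - 1; i >= 0; i--) {
    let s = A[i][K];
    for (let j = i + 1; j < K; j++) s -= A[i][j] * w[j];
    w[i] = s / A[i][i];
  }
  return w;
}

// Each row of the design matrix is [1, x_n], matching the definition of X.
const sampleInputs = [1, 2, 3, 4];
const sampleTargets = [3, 5, 7, 9];
const designMatrix = sampleInputs.map(x => [1, x]);
const weights = solveNormalEquations(designMatrix, sampleTargets);
// weights[0] approximates w_0 and weights[1] approximates w_1.
```

Because the sample points lie on y = 1 + 2x, the returned weights should be close to 1 and 2. Adding further columns to each row of the design matrix (for example x_n^2) would fit a more general model with the same function.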

Implementing the Simple Linear Model Fit

With the result derived, the code is fairly simple. I wrote it in TypeScript:

// Model:
// y = w0 + w1*x
export class Liner {
  constructor(public inputs: number[] = [], public outputs: number[] = []) {}
  // Compute the slope w1 from equation (5)
  getW1(): number {
    // Element-wise products x_n * y_n
    const xt = this.inputs.map((item: number, index: number) => {
      return item * this.outputs[index];
    });
    // Squared inputs x_n^2
    const xx = this.inputs.map(item => item * item);
    const xtMean = Liner.mean(xt);
    const _x = Liner.mean(this.inputs);
    const _y = Liner.mean(this.outputs);
    return (xtMean - _x * _y) / (Liner.mean(xx) - _x * _x);
  }
  // Compute the intercept w0
  getW0(): number {
    return Liner.mean(this.outputs) - this.getW1() * Liner.mean(this.inputs);
  }
  // Predicted value for a given input
  get(input: number): number {
    return this.getW0() + this.getW1() * input;
  }
  // Arithmetic mean of an array
  static mean(arr: number[]): number {
    return arr.reduce((a, b) => a + b, 0) / arr.length;
  }
}
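
In the class above, every call to `get` recomputes both coefficients from the raw data, and `getW0` itself calls `getW1`. When many predictions are needed, one option is to fit once and reuse the result. The following is a minimal sketch of that idea under stated assumptions: `fitLine` and `LineFit` are hypothetical names, not part of the original code, and the check for identical inputs uses an exact comparison as a simplification.

```typescript
// Hypothetical result type: the fitted coefficients plus a prediction function.
interface LineFit {
  w0: number;
  w1: number;
  predict(x: number): number;
}

// Fit y = w0 + w1 * x once using the closed-form least-squares solution.
function fitLine(inputs: number[], outputs: number[]): LineFit {
  if (inputs.length === 0 || inputs.length !== outputs.length) {
    throw new Error("inputs and outputs must be non-empty and of equal length");
  }
  const mean = (arr: number[]): number =>
    arr.reduce((a, b) => a + b, 0) / arr.length;

  const xMean = mean(inputs);
  const yMean = mean(outputs);
  const xyMean = mean(inputs.map((x, i) => x * outputs[i]));
  const xxMean = mean(inputs.map(x => x * x));

  // The denominator is zero when all inputs are identical (simplified check).
  const denominator = xxMean - xMean * xMean;
  if (denominator === 0) {
    throw new Error("all inputs are identical, so the slope is undefined");
  }
  const w1 = (xyMean - xMean * yMean) / denominator;
  const w0 = yMean - w1 * xMean;
  return { w0, w1, predict: (x: number) => w0 + w1 * x };
}

// Fit once, then predict as often as needed without recomputing the means.
const lineFit = fitLine([1, 2, 3], [3, 5, 7]);
const predictionAt4 = lineFit.predict(4);
```

For these three points, which lie on y = 1 + 2x, the prediction at x = 4 should be approximately 9, up to floating-point rounding.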

Reposted from juejin.im/post/5bb21a1a6fb9a05d1f221df7