Parsing the core source code of LinearRegression in sklearn.linear_model

Let us start with a first example that uses LinearRegression; the code is as follows:

from sklearn import linear_model as lm
import numpy as np
import os
import pandas as pd


def read_data(path):
    """Load a CSV source with pandas and hand back the resulting DataFrame.

    :param path: anything ``pandas.read_csv`` accepts (file path or buffer).
    :return: the parsed ``DataFrame``.
    """
    frame = pd.read_csv(path)
    return frame


def train_model(train_data, features, labels):
    """Fit a linear regression on the training rows and return the model.

    The learned intercept and coefficients are printed as a side effect.

    :param train_data: table holding feature and label columns
        (presumably a pandas DataFrame, as produced by ``read_data``).
    :param features: list of column names used as model inputs.
    :param labels: list of column names used as regression targets.
    :return: the fitted ``LinearRegression`` instance.
    """
    regressor = lm.LinearRegression()
    inputs = train_data[features]
    targets = train_data[labels]
    regressor.fit(inputs, targets)
    # Report the fitted parameters: intercept first, then the coefficients.
    print(regressor.intercept_)
    print(regressor.coef_)
    return regressor


def linear_model(data, data_number):
    """Fit a linear model on the leading rows of ``data`` and return it.

    :param data: table of samples whose header contains the columns ``x``
        and ``y`` (presumably the DataFrame returned by ``read_data``).
    :param data_number: number of leading rows used for training. Rows
        ``0 .. data_number - 1`` form the training set; any remaining rows
        are held out and not used here. If ``data_number`` is at least the
        number of rows, every row is used for training.
    :return: the fitted model produced by ``train_model`` (previously the
        model was discarded and ``None`` was returned).
    """
    # Feature / label names; they must match the header row of the data file.
    features = ["x"]
    labels = ["y"]
    # Positional split: slicing stops *before* index data_number.
    train_data = data[:data_number]
    return train_model(train_data, features, labels)


if __name__ == "__main__":
    home_path = os.path.dirname(os.path.abspath(__file__))
    # os.path.join inserts the platform's own separator, so no separate
    # Windows / POSIX branches are needed to build the data path.
    data_path = os.path.join(home_path, "data", "simple_example.csv")
    data = read_data(data_path)

    # The sample simple_example.csv has 14 rows, so data_number=15 puts
    # every row into the training set.
    linear_model(data, data_number=15)

The data file simple_example.csv contains the following:

x,y
10,7.7
10,9.87
11,11.18
12,10.43
13,12.36
14,14.15
15,15.73
16,16.4
17,18.86
18,16.13
19,18.21
20,18.37
21,22.61
22,19.83

The basic idea of the program

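(1) Read the CSV data with pandas and use the first data_number rows (data_number=15 here) as the training set. Note that simple_example.csv contains only 14 data rows, so with data_number=15 every row is used for training and the held-out test set is empty;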

(2) Train the model in two steps: first create a linear regression model, then call fit on the training data. The core code is as follows:

# Step 1: create the (still unfitted) linear regression model.
model = lm.LinearRegression()
# Step 2: fit it on the training feature columns against the label columns.
model.fit(train_data[features], train_data[labels])

The fitted model printed by the program is:

$y = 1.01211289x - 0.62794705$

 

Stepping into the fit method, we can see that its core is the following call:

# Least-squares solve made inside LinearRegression.fit (scipy.linalg.lstsq).
linalg.lstsq (X, y)

This is a method provided by scipy for solving $Ax = b$ in the least-squares sense. Its source lives in scipy/linalg/basic.py in the scipy repository; the full source is as follows:

# Linear Least Squares
def lstsq(a, b, cond=None, overwrite_a=False, overwrite_b=False,
          check_finite=True, lapack_driver=None):
    """
    Compute a least-squares solution of the linear system ``a @ x = b``.

    Parameters
    ----------
    a : array_like, 2-D, shape (m, n)
        Left-hand-side matrix.
    b : array_like, shape (m,) or (m, nrhs)
        Right-hand side; must have the same number of rows as ``a``.
    cond : float, optional
        Cutoff passed to the LAPACK driver. ``None`` means the machine
        epsilon of the selected LAPACK routine's dtype.
    overwrite_a, overwrite_b : bool, optional
        Permission to overwrite the inputs. In this implementation only the
        'gelss' branch forwards these flags to LAPACK.
    check_finite : bool, optional
        Forwarded to ``_asarray_validated`` for both inputs.
    lapack_driver : {'gelsd', 'gelsy', 'gelss'}, optional
        LAPACK routine to use; ``None`` selects
        ``lstsq.default_lapack_driver``.

    Returns
    -------
    x : ndarray, shape (n,) or (n, nrhs)
        Least-squares solution (all zeros for a zero-sized problem).
    residues : ndarray
        For 'gelsd'/'gelss': per-column sum of squares of rows ``n:`` of the
        LAPACK output, filled only when ``m > n`` and ``rank == n``;
        otherwise empty. Always empty for 'gelsy'.
    rank : int
        Effective rank reported by LAPACK (0 for a zero-sized problem).
    s : ndarray or None
        Singular values from 'gelsd'/'gelss'; ``None`` for 'gelsy'.

    Raises
    ------
    ValueError
        If ``a`` is not 2-D, the row counts differ, the driver name is
        unknown, or LAPACK reports an illegal argument.
    LinAlgError
        If 'gelsd'/'gelss' report that the SVD did not converge.
    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)
    if len(a1.shape) != 2:
        raise ValueError('Input array a should be 2-D')
    m, n = a1.shape
    if len(b1.shape) == 2:
        nrhs = b1.shape[1]
    else:
        nrhs = 1
    if m != b1.shape[0]:
        raise ValueError('Shape mismatch: a and b should have the same number'
                         ' of rows ({} != {}).'.format(m, b1.shape[0]))
    if m == 0 or n == 0:  # Zero-sized problem, confuses LAPACK
        x = np.zeros((n,) + b1.shape[1:], dtype=np.common_type(a1, b1))
        if n == 0:
            residues = np.linalg.norm(b1, axis=0)**2
        else:
            residues = np.empty((0,))
        return x, residues, 0, np.empty((0,))

    driver = lapack_driver
    if driver is None:
        driver = lstsq.default_lapack_driver
    if driver not in ('gelsd', 'gelsy', 'gelss'):
        raise ValueError('LAPACK driver "%s" is not found' % driver)

    lapack_func, lapack_lwork = get_lapack_funcs((driver,
                                                 '%s_lwork' % driver),
                                                 (a1, b1))
    # dtype kind 'f' means real floating point; anything else takes the
    # complex-valued gelsd call below.
    real_data = lapack_func.dtype.kind == 'f'

    if m < n:
        # need to extend b matrix as it will be filled with
        # a larger solution matrix (the LAPACK wrapper expects b to have
        # max(m, n) rows); the extra rows are zero-padded.
        if len(b1.shape) == 2:
            b2 = np.zeros((n, nrhs), dtype=lapack_func.dtype)
            b2[:m, :] = b1
        else:
            b2 = np.zeros(n, dtype=lapack_func.dtype)
            b2[:m] = b1
        b1 = b2

    # If validation already produced a private copy, overwriting it is safe.
    overwrite_a = overwrite_a or _datacopied(a1, a)
    overwrite_b = overwrite_b or _datacopied(b1, b)

    if cond is None:
        cond = np.finfo(lapack_func.dtype).eps

    if driver in ('gelss', 'gelsd'):
        if driver == 'gelss':
            lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
            v, x, s, rank, work, info = lapack_func(a1, b1, cond, lwork,
                                                    overwrite_a=overwrite_a,
                                                    overwrite_b=overwrite_b)

        elif driver == 'gelsd':
            # NOTE(review): unlike 'gelss', the overwrite flags computed
            # above are not forwarded here; False, False are passed.
            if real_data:
                lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork,
                                               iwork, cond, False, False)
            else:  # complex data
                lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
                                                     nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
                                               cond, False, False)
        if info > 0:
            raise LinAlgError("SVD did not converge in Linear Least Squares")
        if info < 0:
            # Report the resolved driver name; the raw lapack_driver argument
            # is None whenever the default driver was used.
            raise ValueError('illegal value in %d-th argument of internal %s'
                             % (-info, driver))
        resids = np.asarray([], dtype=x.dtype)
        if m > n:
            x1 = x[:n]
            if rank == n:
                resids = np.sum(np.abs(x[n:])**2, axis=0)
            x = x1
        return x, resids, rank, s

    elif driver == 'gelsy':
        lwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
        jptv = np.zeros((a1.shape[1], 1), dtype=np.int32)
        v, x, j, rank, info = lapack_func(a1, b1, jptv, cond,
                                          lwork, False, False)
        if info < 0:
            raise ValueError("illegal value in %d-th argument of internal "
                             "gelsy" % -info)
        if m > n:
            # Only the first n rows of the LAPACK output hold the solution.
            x = x[:n]
        return x, np.array([], x.dtype), rank, None


lstsq.default_lapack_driver = 'gelsd'

The function supports three LAPACK drivers: 'gelsd', 'gelsy' and 'gelss'. Our example does not pass a lapack_driver argument, so the default, 'gelsd', is used. We can therefore focus on the following code:

elif driver == 'gelsd':
            if real_data:
                # Real input: the work-size query yields the float workspace
                # size and the integer workspace size.
                lwork, iwork = _compute_lwork(lapack_lwork, m, n, nrhs, cond)
                # Positional order follows dgelsd(a, b, lwork, size_iwork,
                # cond, overwrite_a, overwrite_b); both overwrites are False.
                x, s, rank, info = lapack_func(a1, b1, lwork,
                                               iwork, cond, False, False)
            else:  # complex data
                # Complex input additionally needs a real workspace (rwork).
                lwork, rwork, iwork = _compute_lwork(lapack_lwork, m, n,
                                                     nrhs, cond)
                x, s, rank, info = lapack_func(a1, b1, lwork, rwork, iwork,
                                               cond, False, False)

Because our data are real numbers, the `if real_data:` branch is taken, which calls _compute_lwork and then lapack_func. _compute_lwork is defined in lapack.py, in the same directory as basic.py; its source is as follows:

def _compute_lwork(routine, *args, **kwargs):
    """
    Round floating-point lwork returned by lapack to integer.

    Several LAPACK routines compute optimal values for LWORK, which
    they return in a floating-point variable. However, for large
    values of LWORK, single-precision floating point is not sufficient
    to hold the exact value --- some LAPACK versions (<= 3.5.0 at
    least) truncate the returned integer to single precision and in
    some cases this can be smaller than the required value.

    Parameters
    ----------
    routine : callable
        Workspace-query routine (e.g. a ``*_lwork`` LAPACK wrapper). It must
        return one or more work-size values followed by an ``info`` status.
    *args, **kwargs
        Forwarded unchanged to ``routine``.

    Returns
    -------
    numpy.int32 or ndarray of numpy.int32
        A scalar when the routine reports a single work size, otherwise an
        array with one entry per reported work size.

    Raises
    ------
    ValueError
        If the routine returns fewer than two values, reports a non-zero
        ``info``, or a work size is negative or does not fit in int32.

    Examples
    --------
    >>> from scipy.linalg import lapack
    >>> n = 5000
    >>> s_r, s_lw = lapack.get_lapack_funcs(('sysvx', 'sysvx_lwork'))
    >>> lwork = lapack._compute_lwork(s_lw, n)
    >>> lwork
    32000

    """
    wi = routine(*args, **kwargs)
    if len(wi) < 2:
        # Previously raised with an empty message, which gave no hint of
        # what went wrong.
        raise ValueError("Internal work array size query returned %d "
                         "value(s); expected work size(s) followed by info."
                         % (len(wi),))
    info = wi[-1]
    if info != 0:
        raise ValueError("Internal work array size computation failed: "
                         "%d" % (info,))

    # Work sizes may come back as complex numbers from complex routines
    # (presumably); only the real part carries the size.
    lwork = [w.real for w in wi[:-1]]

    dtype = getattr(routine, 'dtype', None)
    if dtype == _np.float32 or dtype == _np.complex64:
        # Single-precision routine -- take next fp value to work
        # around possible truncation in LAPACK code
        lwork = _np.nextafter(lwork, _np.inf, dtype=_np.float32)

    lwork = _np.array(lwork, _np.int64)
    if _np.any(_np.logical_or(lwork < 0, lwork > _np.iinfo(_np.int32).max)):
        raise ValueError("Too large work array required -- computation cannot "
                         "be performed with standard 32-bit LAPACK.")
    lwork = lwork.astype(_np.int32)
    if lwork.size == 1:
        return lwork[0]
    return lwork

lapack_func is obtained from get_lapack_funcs, which is also defined in lapack.py; its source is as follows:

def get_lapack_funcs(names, arrays=(), dtype=None):
    """
    Look up LAPACK function object(s) by name (docstring abridged).

    In LAPACK, the naming convention is that all functions start with a
    type prefix, which depends on the type of the principal
    matrix. These can be one of {'s', 'd', 'c', 'z'} for the numpy
    types {float32, float64, complex64, complex128} respectively, and
    are stored in attribute ``typecode`` of the returned functions.

    The actual lookup is delegated to ``_get_funcs`` with the LAPACK
    backend modules, their names, and the LAPACK alias table.
    """
    # Fortran module, C module, their display names, then the alias table.
    backends = (_flapack, _clapack, "flapack", "clapack", _lapack_alias)
    return _get_funcs(names, arrays, dtype, "LAPACK", *backends)

The actual lookup is performed by _get_funcs. As the docstring explains, every LAPACK function name starts with a type prefix determined by the type of the principal matrix: one of 's', 'd', 'c', 'z' for float32, float64, complex64 and complex128 respectively. Our data are float64, which corresponds to 'd', so the function ultimately called is:

scipy.linalg.lapack.dgelsd

Figure:

The object actually called is a Fortran routine wrapped by f2py. Its signature is declared in flapack_gen.pyf.src in the same directory, written in f2py's Fortran-style signature language:

! f2py signature for the LAPACK ?gelsd least-squares driver.
! <prefix2>, <ftype2> and <ctype2> are template placeholders expanded when
! this .pyf.src is processed (presumably into the real s/d variants --
! TODO confirm against the template definitions).
subroutine <prefix2>gelsd(m,n,minmn,maxmn,nrhs,a,b,s,cond,r,work,lwork,size_iwork,iwork,info)
    ! x,s,rank,info = dgelsd(a,b,lwork,size_iwork,cond=-1.0,overwrite_a=True,overwrite_b=True)
    ! Solve Minimize 2-norm(A * X - B).
    ! NOTE(review): a and b are declared intent(copy) below; f2py usually
    ! gives overwrite_* a default of False for intent(copy), which would
    ! contradict the "=True" in the usage line above -- confirm.

    ! Actual Fortran call: the leading dimension of a is m, and the leading
    ! dimension of b is maxmn.
    callstatement (*f2py_func)(&m,&n,&nrhs,a,&m,b,&maxmn,s,&cond,&r,work,&lwork,iwork,&info)
    callprotoargument int*,int*,int*,<ctype2>*,int*,<ctype2>*,int*,<ctype2>*,<ctype2>*,int*,<ctype2>*,int*,int*,int*

    ! Dimensions are hidden from Python and derived from the shape of a.
    integer intent(hide),depend(a):: m = shape(a,0)
    integer intent(hide),depend(a):: n = shape(a,1)
    integer intent(hide),depend(m,n):: minmn = MIN(m,n)
    integer intent(hide),depend(m,n):: maxmn = MAX(m,n)
    <ftype2> dimension(m,n),intent(in,copy) :: a

    ! b must have exactly maxmn rows (see the check); the Python-level
    ! lstsq zero-pads b when m < n for this reason. b is both input and
    ! output and is returned to Python under the name x.
    integer depend(b),intent(hide):: nrhs = shape(b,1)
    <ftype2> dimension(maxmn,nrhs),check(maxmn==shape(b,0)),depend(maxmn) :: b
    intent(in,out,copy,out=x) b

    ! Optional cutoff (default -1.0); r is returned as rank; s has minmn
    ! entries and is returned to the caller.
    <ftype2> intent(in),optional :: cond=-1.0
    integer intent(out,out=rank) :: r
    <ftype2> intent(out),dimension(minmn),depend(minmn) :: s

    ! The caller supplies lwork (at least 1); the Python-level lstsq obtains
    ! it from the matching *_lwork query routine via _compute_lwork.
    integer intent(in),check(lwork>=1) :: lwork
    ! Impossible to calculate lwork explicitly, need to obtain it from query call first
    ! Same for size_iwork
    ! Workspace arrays are hidden from Python and cached between calls.
    <ftype2> dimension(lwork),intent(cache,hide),depend(lwork) :: work

    integer intent(in) :: size_iwork
    integer intent(cache,hide),dimension(MAX(1,size_iwork)),depend(size_iwork) :: iwork
    ! LAPACK status code, returned to Python as info.
    integer intent(out)::info

end subroutine <prefix2>gelsd

 

Published 75 original articles · won praise 54 · views 80000 +

Guess you like

Origin blog.csdn.net/L_15156024189/article/details/104774649