十一、加权线性回归案例:预测鲍鱼的年龄

加权线性回归案例:预测鲍鱼的年龄

点击文章标题即可获取源代码和笔记
数据集:https://download.csdn.net/download/weixin_44827418/12553408

1.导入数据集

数据集描述:
在这里插入图片描述

import pandas as pd
import numpy as np

abalone = pd.read_table("./datas/abalone.txt",header=None)
abalone.columns=['性别','长度','直径','高度','整体重量','肉重量','内脏重量','壳重','年龄']
abalone.head()
性别 长度 直径 高度 整体重量 肉重量 内脏重量 壳重 年龄
0 1 0.455 0.365 0.095 0.5140 0.2245 0.1010 0.150 15
1 1 0.350 0.265 0.090 0.2255 0.0995 0.0485 0.070 7
2 -1 0.530 0.420 0.135 0.6770 0.2565 0.1415 0.210 9
3 1 0.440 0.365 0.125 0.5160 0.2155 0.1140 0.155 10
4 0 0.330 0.255 0.080 0.2050 0.0895 0.0395 0.055 7
abalone.shape
(4177, 9)
abalone.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4177 entries, 0 to 4176
Data columns (total 9 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   性别      4177 non-null   int64  
 1   长度      4177 non-null   float64
 2   直径      4177 non-null   float64
 3   高度      4177 non-null   float64
 4   整体重量    4177 non-null   float64
 5   肉重量     4177 non-null   float64
 6   内脏重量    4177 non-null   float64
 7   壳重      4177 non-null   float64
 8   年龄      4177 non-null   int64  
dtypes: float64(7), int64(2)
memory usage: 293.8 KB
abalone.describe()
性别 长度 直径 高度 整体重量 肉重量 内脏重量 壳重 年龄
count 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000
mean 0.052909 0.523992 0.407881 0.139516 0.828742 0.359367 0.180594 0.238831 9.933684
std 0.822240 0.120093 0.099240 0.041827 0.490389 0.221963 0.109614 0.139203 3.224169
min -1.000000 0.075000 0.055000 0.000000 0.002000 0.001000 0.000500 0.001500 1.000000
25% -1.000000 0.450000 0.350000 0.115000 0.441500 0.186000 0.093500 0.130000 8.000000
50% 0.000000 0.545000 0.425000 0.140000 0.799500 0.336000 0.171000 0.234000 9.000000
75% 1.000000 0.615000 0.480000 0.165000 1.153000 0.502000 0.253000 0.329000 11.000000
max 1.000000 0.815000 0.650000 1.130000 2.825500 1.488000 0.760000 1.005000 29.000000

2. 查看数据分布状况

import numpy as np
import pandas as pd
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['simhei'] #显示中文
plt.rcParams['axes.unicode_minus']=False # 用来正常显示负号  
%matplotlib inline
mpl.cm.rainbow(np.linspace(0,1,10))
array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
       [2.80392157e-01, 3.38158275e-01, 9.85162233e-01, 1.00000000e+00],
       [6.07843137e-02, 6.36474236e-01, 9.41089253e-01, 1.00000000e+00],
       [1.66666667e-01, 8.66025404e-01, 8.66025404e-01, 1.00000000e+00],
       [3.86274510e-01, 9.84086337e-01, 7.67362681e-01, 1.00000000e+00],
       [6.13725490e-01, 9.84086337e-01, 6.41213315e-01, 1.00000000e+00],
       [8.33333333e-01, 8.66025404e-01, 5.00000000e-01, 1.00000000e+00],
       [1.00000000e+00, 6.36474236e-01, 3.38158275e-01, 1.00000000e+00],
       [1.00000000e+00, 3.38158275e-01, 1.71625679e-01, 1.00000000e+00],
       [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00]])
mpl.cm.rainbow(np.linspace(0,1,10))[0]
array([0.5, 0. , 1. , 1. ])
def dataPlot(dataSet):
    m,n = dataSet.shape
    fig = plt.figure(figsize=(8,20),dpi=100)
    colormap = mpl.cm.rainbow(np.linspace(0,1,n))
    for i in range(n):
        fig_ = fig.add_subplot(n,1,i+1)
        plt.scatter(range(m),dataSet.iloc[:,i].values,s=2,c=colormap[i])
        plt.title(dataSet.columns[i])
        plt.tight_layout(pad=1.2) # 调节子图间的距离
# 运行函数,查看数据分布:
dataPlot(abalone)
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sUDRrFEr-1593153198969)(output_10_1.png)]

可以从数据分布散点图中看出:

1)除“性别”之外,其他数据明显存在规律性排列

2)“高度”这一特征中,有两个异常值

从看到的现象,我们可以采取以下两种措施:

1) 切分训练集和测试集时,需要打乱原始数据集来进行随机挑选

2) 剔除"高度"这一特征中的异常值

abalone['高度']<0.4
0       True
1       True
2       True
3       True
4       True
        ... 
4172    True
4173    True
4174    True
4175    True
4176    True
Name: 高度, Length: 4177, dtype: bool
aba = abalone.loc[abalone['高度']<0.4,:]
#再次查看数据集的分布
dataPlot(aba)
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-rhcXvPsH-1593153198971)(output_18_1.png)]

2. 切分训练集和测试集

"""
函数功能:随机切分训练集和测试集
参数说明:
    dataSet:原始数据集
    rate:训练集比例
返回:
    train,test:切分好的训练集和测试集
"""
def randSplit(dataSet,rate):
    l = list(dataSet.index) # 将原始数据集的索引提取出来,存到列表中
    random.seed(123) # 设置随机数种子
    random.shuffle(l) # 随机打乱数据集中的索引
    dataSet.index = l # 把打乱后的索引重新赋值给数据集中的索引,
    # 索引打乱了就相当于打乱了原始数据集中的数据
    m = dataSet.shape[0] # 原始数据集样本总数
    n = int(m*rate) # 训练集样本数量
    train = dataSet.loc[range(n),:] # 从打乱了的原始数据集中提取出训练集数据
    test = dataSet.loc[range(n,m),:] # 从打乱了的原始数据集中提取出测试集数据
    train.index = range(train.shape[0]) # 重置train训练数据集中的索引
    test.index = range(test.shape[0]) # 重置test测试数据集中的索引
    dataSet.index = range(dataSet.shape[0]) # 重置原始数据集中的索引
    return train,test
train,test = randSplit(aba,0.8)
#探索训练集
train.head()
性别 长度 直径 高度 整体重量 肉重量 内脏重量 壳重 年龄
0 -1 0.590 0.470 0.170 0.9000 0.3550 0.1905 0.2500 11
1 1 0.560 0.450 0.145 0.9355 0.4250 0.1645 0.2725 11
2 -1 0.635 0.535 0.190 1.2420 0.5760 0.2475 0.3900 14
3 1 0.505 0.390 0.115 0.5585 0.2575 0.1190 0.1535 8
4 1 0.510 0.410 0.145 0.7960 0.3865 0.1815 0.1955 8
train.shape
(3340, 9)
abalone.describe()
性别 长度 直径 高度 整体重量 肉重量 内脏重量 壳重 年龄
count 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000 4177.000000
mean 0.052909 0.523992 0.407881 0.139516 0.828742 0.359367 0.180594 0.238831 9.933684
std 0.822240 0.120093 0.099240 0.041827 0.490389 0.221963 0.109614 0.139203 3.224169
min -1.000000 0.075000 0.055000 0.000000 0.002000 0.001000 0.000500 0.001500 1.000000
25% -1.000000 0.450000 0.350000 0.115000 0.441500 0.186000 0.093500 0.130000 8.000000
50% 0.000000 0.545000 0.425000 0.140000 0.799500 0.336000 0.171000 0.234000 9.000000
75% 1.000000 0.615000 0.480000 0.165000 1.153000 0.502000 0.253000 0.329000 11.000000
max 1.000000 0.815000 0.650000 1.130000 2.825500 1.488000 0.760000 1.005000 29.000000
train.describe() #统计描述
性别 长度 直径 高度 整体重量 肉重量 内脏重量 壳重 年龄
count 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000 3340.000000
mean 0.060479 0.522754 0.406886 0.138790 0.824906 0.358151 0.179732 0.237158 9.911976
std 0.819021 0.120300 0.099372 0.038441 0.488535 0.222422 0.109036 0.137920 3.223534
min -1.000000 0.075000 0.055000 0.000000 0.002000 0.001000 0.000500 0.001500 1.000000
25% -1.000000 0.450000 0.350000 0.115000 0.439000 0.184375 0.092000 0.130000 8.000000
50% 0.000000 0.540000 0.420000 0.140000 0.796750 0.335500 0.171000 0.232000 9.000000
75% 1.000000 0.615000 0.480000 0.165000 1.147250 0.498500 0.250500 0.325000 11.000000
max 1.000000 0.780000 0.630000 0.250000 2.825500 1.488000 0.760000 1.005000 27.000000
dataPlot(train) #查看训练集数据分布
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-sIC8Ac3y-1593153198972)(output_26_1.png)]

#探索测试集
test.head() 
性别 长度 直径 高度 整体重量 肉重量 内脏重量 壳重 年龄
0 1 0.630 0.470 0.150 1.1355 0.5390 0.2325 0.3115 12
1 -1 0.585 0.445 0.140 0.9130 0.4305 0.2205 0.2530 10
2 -1 0.390 0.290 0.125 0.3055 0.1210 0.0820 0.0900 7
3 1 0.525 0.410 0.130 0.9900 0.3865 0.2430 0.2950 15
4 1 0.625 0.475 0.160 1.0845 0.5005 0.2355 0.3105 10
test.shape 
(835, 9)
test.describe() 
性别 长度 直径 高度 整体重量 肉重量 内脏重量 壳重 年龄
count 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000 835.000000
mean 0.022754 0.528808 0.411737 0.140784 0.842714 0.363370 0.183749 0.245320 10.022754
std 0.834341 0.119166 0.098627 0.038664 0.495990 0.218938 0.111510 0.143925 3.230284
min -1.000000 0.130000 0.100000 0.015000 0.013000 0.004500 0.003000 0.004000 3.000000
25% -1.000000 0.450000 0.350000 0.115000 0.458000 0.192000 0.096500 0.132750 8.000000
50% 0.000000 0.550000 0.430000 0.140000 0.810000 0.339000 0.170500 0.235000 10.000000
75% 1.000000 0.620000 0.485000 0.170000 1.177250 0.510750 0.259250 0.337000 11.000000
max 1.000000 0.815000 0.650000 0.250000 2.555000 1.145500 0.590000 0.815000 29.000000
dataPlot(test)
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.
'c' argument looks like a single numeric RGB or RGBA sequence, which should be avoided as value-mapping will have precedence in case its length matches with 'x' & 'y'.  Please use a 2-D array with a single row if you really want to specify the same RGB or RGBA value for all points.

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-MjIwyXmw-1593153198974)(output_30_1.png)]

3.构建辅助函数

'''
函数功能:输入DF数据集(最后一列为标签),返回特征矩阵和标签矩阵
'''
def get_Mat(dataSet):
    xMat = np.mat(dataSet.iloc[:,:-1].values)
    yMat = np.mat(dataSet.iloc[:,-1].values).T
    return xMat,yMat

'''
函数功能:数据集可视化
'''
def plotShow(dataSet):
    xMat,yMat = get_Mat(dataSet)
    plt.scatter(xMat.A[:,1],yMat.A,c='b',s=5)
    plt.show()

'''
函数功能:计算回归系数
参数说明:
    dataSet:原始数据集
返回:
    ws:回归系数
'''
def standRegres(dataSet):
    xMat,yMat = get_Mat(dataSet)
    xTx = xMat.T * xMat
    if np.linalg.det(xTx) == 0:
        print('矩阵为奇异矩阵,无法求逆!')
        return
    ws = xTx.I*(xMat.T*yMat) # xTx.I ,用来求逆矩阵
    return ws
"""
函数功能:计算误差平方和SSE
参数说明:
    dataSet:真实值
    regres:求回归系数的函数
返回:
    SSE:误差平方和
"""
def sseCal(dataSet, regres):
    xMat,yMat = get_Mat(dataSet)
    ws = regres(dataSet)
    yHat = xMat*ws
    sse = ((yMat.A.flatten() - yHat.A.flatten())**2).sum()#  
    return sse

以ex0数据集为例,查看函数运行结果:

ex0 = pd.read_table("./datas/ex0.txt",header=None)
ex0.head()
0 1 2
0 1.0 0.067732 3.176513
1 1.0 0.427810 3.816464
2 1.0 0.995731 4.550095
3 1.0 0.738336 4.256571
4 1.0 0.981083 4.560815
#简单线性回归的SSE
sseCal(ex0, standRegres)
1.3552490816814902

构建相关系数R2计算函数

"""
函数功能:计算相关系数R2
"""
def rSquare(dataSet,regres):
    xMat,yMat=get_Mat(dataSet)
    sse = sseCal(dataSet,regres)
    sst = ((yMat.A-yMat.mean())**2).sum()#  
    r2 = 1 - sse / sst
    return r2

同样以ex0数据集为例,查看函数运行结果:

#简单线性回归的R2
rSquare(ex0, standRegres)
0.9731300889856916
'''
函数功能:计算局部加权线性回归的预测值
参数说明:
    testMat:测试集
    xMat:训练集的特征矩阵
    yMat:训练集的标签矩阵
    返回:
        yHat:函数预测值
'''
def LWLR(testMat,xMat,yMat,k=1.0):
    n = testMat.shape[0] # 测试数据集行数
    m = xMat.shape[0] # 训练集特征矩阵行数
    weights = np.mat(np.eye(m)) # 用单位矩阵来初始化权重矩阵,
    yHat = np.zeros(n) # 用0矩阵来初始化预测值矩阵
    for i in range(n):
        for j in range(m):
            diffMat = testMat[i] - xMat[j]
            weights[j,j] = np.exp(diffMat*diffMat.T / (-2*k**2))
        xTx = xMat.T*(weights*xMat)
        if np.linalg.det(xTx) == 0:
            print('矩阵为奇异矩阵,无法求逆')
            return
        ws = xTx.I*(xMat.T*(weights*yMat))
        yHat[i] = testMat[i] * ws
    return ws,yHat

4.构建加权线性模型

因为数据量太大,计算速度极慢,所以此处选择训练集的前100个数据作为训练集,测试集的前100个数据作为测试集。

"""
函数功能:绘制不同k取值下,训练集和测试集的SSE曲线
"""
def ssePlot(train,test):
    X0,Y0 = get_Mat(train)
    X1,Y1 =get_Mat(test)
    train_sse = []
    test_sse = []
    for k in np.arange(0.2,10,0.5):
        ws1,yHat1 = LWLR(X0[:99],X0[:99],Y0[:99],k) 
        sse1 = ((Y0[:99].A.T - yHat1)**2).sum() 
        train_sse.append(sse1)
        
        ws2,yHat2 = LWLR(X1[:99],X0[:99],Y0[:99],k) 
        sse2 = ((Y1[:99].A.T - yHat2)**2).sum() 
        test_sse.append(sse2)
        
    plt.figure(figsize=(20,8),dpi=100)
    plt.plot(np.arange(0.2,10,0.5),train_sse,color='b')#     
    plt.plot(np.arange(0.2,10,0.5),test_sse,color='r') 
    plt.xlabel('不同k取值')
    plt.ylabel('SSE')
    plt.legend(['train_sse','test_sse'])

运行结果:

ssePlot(train,test)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BXGhyRcs-1593153198975)(output_47_0.png)]

这个图的解读应该是这样的:从右往左看,当K取较大值时,模型比较稳定,随着K值的减小,训练集的SSE开始逐渐减小,当K取到2左右,训练集的SSE与测试集的SSE相等,当K继续减小时,训练集的SSE也越来越小,也就是说,模型在训练集上的表现越来越好,但是,模型在测试集上的表现却越来越差了,这就说明模型开始出现过拟合了。其实,这个图与前面不同k值的结果图是吻合的,K=1.0,
0.01, 0.003这三张图也表明随着K的减小,模型会逐渐出现过拟合。所以这里可以看出,K在2左右的取值最佳。

我们再将K=2带入局部线性回归模型中,然后查看预测结果:

train,test = randSplit(aba,0.8) # 随机切分原始数据集,得到训练集和测试集
trainX,trainY = get_Mat(train) # 将切分好的训练集分成特征矩阵和标签矩阵
testX,testY = get_Mat(test) # 将切分好的测试集分成特征矩阵和标签矩阵
ws0,yHat0 = LWLR(testX,trainX,trainY,k=2)

绘制真实值与预测值之间的关系图

y=testY.A.flatten()
plt.scatter(y,yHat0,c='b',s=5); # ;等效于plt.show()

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-y9Wfstwl-1593153198976)(output_52_0.png)]

通过上图可知,横坐标为真实值,纵坐标为预测值,形成的图像为呈现一个“喇叭形”,随着横坐标真实值逐渐变大,纵坐标预测值也越来越大,说明随着真实值的增加,预测值偏差越来越大

封装一个函数来计算SSE和R方,方便后续调用

"""
函数功能:计算加权线性回归的SSE和R方
"""
def LWLR_pre(dataSet):
    train,test = randSplit(dataSet,0.8)#      
    trainX,trainY = get_Mat(train)
    testX,testY = get_Mat(test)
    ws,yHat = LWLR(testX,trainX,trainY,k=2)#     
    sse = ((testY.A.T - yHat)**2).sum()#     
    sst = ((testY.A-testY.mean())**2).sum() #     
    r2 = 1 - sse / sst
    return sse,r2

查看模型预测结果

LWLR_pre(aba)
(4152.777097646255, 0.5228101340130846)

从结果可以看出,SSE达4000+,相关系数只有0.52,模型效果并不是很好。

猜你喜欢

转载自blog.csdn.net/weixin_44827418/article/details/106969645