Introduction to Machine Learning: A Linear Regression Lab Record

1. Experiment Description

  • A data file, Advertising.csv, records product advertising spend and the corresponding product sales. Use this file to build a linear model, train it, and run predictions with it. (The full data is listed in the appendix dataset.)
  • Main steps:
    • Load the csv file
    • Extract the label and feature data
    • Plot the relationship between the features and the label
    • Split the dataset
    • Build the model
    • Predict with the model
    • Evaluate the model

2. Skills Involved

  • Python programming
  • Pandas programming
  • Using Sklearn
  • Linear regression modeling
  • Plotting with matplotlib

3. Background Knowledge

  • Reading csv files with Pandas
  • Extracting feature and label data with Pandas
  • Splitting a dataset
  • Linear models
  • Model prediction
  • Model evaluation

4. Expected Result

  • Predict on the test set with the linear regression model; the figure below shows how well the predicted sales fit the actual sales:

[Figure: actual vs. predicted sales on the test set]

5. Experiment Steps

5.1 Import the required packages

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error  # mean squared error

5.2 Read the data file

path = "file/Advertising.csv"
data = pd.read_csv(path)   # read the csv file
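
Note: the first column of the appendix csv is an unnamed row index, so the call above leaves it in the DataFrame as a column named "Unnamed: 0". A minimal alternative sketch, assuming the same file layout, loads that column as the index instead:

# Treat the first (unnamed) column as the row index rather than as a feature
data = pd.read_csv(path, index_col=0)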

5.3 Print the first few rows of the file

print(data.head(10))

5.4 Show the shape of the data

print(data.shape)

5.5 Use pandas to select the relevant columns as the features X and the label Y

x = data[['TV', 'Radio', 'Newspaper']]
y = data['Sales']

5.6 Plot the relationship between each feature and the label

plt.figure(figsize=(9, 12))    # size of the figure
plt.subplot(311)               # first subplot in a 3-row, 1-column grid
plt.plot(data['TV'], y, 'ro')  # 'ro': 'r' sets the color (red), 'o' the marker (circle); see the format-string note below
plt.title('TV')                # subplot title
plt.grid()                     # draw grid lines
plt.subplot(312)               # second subplot, same layout
plt.plot(data['Radio'], y, 'b*')        # 'b*': blue, star marker
plt.title('Radio')
plt.grid()
plt.subplot(313)
plt.plot(data['Newspaper'], y, 'g^')    # 'g^': green, triangle marker
plt.title('Newspaper')
plt.grid()
plt.show()

For reference, in a matplotlib format string the leading character sets the color ('r' red, 'g' green, 'b' blue, 'k' black, 'y' yellow, 'c' cyan, 'm' magenta, 'w' white) and the rest sets the line or marker style ('-' solid line, '--' dashed, ':' dotted, 'o' circle, '*' star, '^' triangle, 's' square, '+' plus).

5.7 From the plots above, the money spent on Newspaper ads is not linearly related to product sales, so we can try dropping that feature before modeling; a quick numeric check follows the code below.

x = data[['TV', 'Radio']]
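
To back the visual impression with a number, here is a quick check (using pandas' DataFrame.corrwith) of each feature's Pearson correlation with the label; Newspaper's should come out much weaker than TV's or Radio's:

# Pearson correlation between each feature and the label
print(data[['TV', 'Radio', 'Newspaper']].corrwith(data['Sales']))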

5.8 Use sklearn's model_selection utilities to split the dataset into a training set and a test set at a 7:3 ratio

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=23)
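
A quick sanity check, just printing shapes, confirms the 7:3 split (140 training rows and 60 test rows for the 200-row appendix dataset):

# Verify the split sizes: roughly 70% train, 30% test
print(x_train.shape, x_test.shape)   # expected: (140, 2) (60, 2)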

5.9 Build the model with sklearn's LinearRegression class. In older scikit-learn versions, normalize=True asked the estimator to rescale the training features (this is feature scaling, not regularization); the parameter was removed in scikit-learn 1.2, so it is dropped here. n_jobs=-1 uses all available CPU cores.

lr = LinearRegression(n_jobs=-1)    # normalize= was removed in scikit-learn 1.2
model = lr.fit(x_train, y_train)    # train the model on the training data
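
If you still want rescaled features on a current scikit-learn, a minimal sketch uses the standard StandardScaler and make_pipeline utilities (StandardScaler standardizes, which is not identical to the old normalize= behavior; for plain least squares, scaling changes the reported coefficients but not the predictions):

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Standardize the features, then fit the linear model on the scaled values
scaled_model = make_pipeline(StandardScaler(), LinearRegression(n_jobs=-1))
scaled_model.fit(x_train, y_train)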

5.10 Print the model parameters

print(lr.intercept_)  # intercept of the linear model
print(lr.coef_)       # estimated coefficients
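
Putting the two together, a small formatting sketch writes out the fitted equation in full:

# Assemble the learned equation: Sales ≈ intercept + coef_TV * TV + coef_Radio * Radio
terms = " + ".join(f"{c:.4f}*{name}" for c, name in zip(lr.coef_, x.columns))
print(f"Sales ≈ {lr.intercept_:.4f} + {terms}")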

5.11 Make predictions with the trained model

y_pred = model.predict(x_test)

5.12 Evaluate the model with RMSE (root mean squared error)

mse = mean_squared_error(y_test, y_pred)  # pass the actual labels y_test and the predicted labels y_pred
print("MSE  : ", mse)            # mean squared error
print("RMSE : ", np.sqrt(mse))   # root mean squared error

5.13 Plot the actual and predicted label values together to get a visual sense of the fit.

plt.figure()
plt.plot(range(len(y_pred)), y_pred, 'b', label='predict')
plt.plot(range(len(y_test)), y_test, 'r', label='test')
plt.legend(loc='upper right')       # show the legend in the upper-right corner
plt.xlabel("the num of sales")      # x-axis label
plt.ylabel("value of sales")        # y-axis label
plt.title("sales real with pred")   # figure title
plt.show()

[Figure: actual vs. predicted sales curves on the test set]
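
As an alternative view, a scatter of predicted against actual values puts a perfect fit on the diagonal; a minimal sketch:

plt.figure()
plt.scatter(y_test, y_pred)                       # one point per test sample
lims = [min(y_test.min(), y_pred.min()), max(y_test.max(), y_pred.max())]
plt.plot(lims, lims, 'k--', label='perfect fit')  # reference line y = x
plt.xlabel("actual sales")
plt.ylabel("predicted sales")
plt.legend()
plt.show()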

6. Reference Solution

  • Code listing: lr1.py

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error  # mean squared error

path = "file/Advertising.csv"
data = pd.read_csv(path)
print(data.head(10))
print(data.shape)
x = data[['TV', 'Radio', 'Newspaper']]
y = data['Sales']
# plt.figure(figsize=(9, 12))    # size of the figure
# plt.subplot(311)               # first subplot in a 3-row, 1-column grid
# plt.plot(data['TV'], y, 'ro')  # 'ro': 'r' sets the color (red), 'o' the marker (circle)
# plt.title('TV')                # subplot title
# plt.grid()                     # draw grid lines
# plt.subplot(312)               # second subplot, same layout
# plt.plot(data['Radio'], y, 'b*')
# plt.title('Radio')
# plt.grid()
# plt.subplot(313)
# plt.plot(data['Newspaper'], y, 'g^')    # 'g^': green, triangle marker
# plt.title('Newspaper')
# plt.grid()
# plt.show()


x = data[['TV', 'Radio']]
# Split the dataset into training and test sets at a 7:3 ratio
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=23)
# Build the linear regression model; n_jobs=-1 uses all CPU cores
# (normalize= was removed in scikit-learn 1.2)
lr = LinearRegression(n_jobs=-1)
model = lr.fit(x_train, y_train)    # train the model on the training data
print(lr.intercept_)  # intercept of the linear model
print(lr.coef_)       # estimated coefficients
y_pred = model.predict(x_test)
mse = mean_squared_error(y_test, y_pred)  # pass the actual labels y_test and the predicted labels y_pred
print("MSE  : ", mse)            # mean squared error
print("RMSE : ", np.sqrt(mse))   # root mean squared error
plt.figure()
plt.plot(range(len(y_pred)), y_pred, 'b', label='predict')
plt.plot(range(len(y_test)), y_test, 'r', label='test')
plt.legend(loc='upper right')       # show the legend in the upper-right corner
plt.xlabel("the num of sales")      # x-axis label
plt.ylabel("value of sales")        # y-axis label
plt.title("sales real with pred")   # figure title
plt.show()
# print(data)

7. Summary

Completing this experiment covers the basics of linear regression modeling, in both theory and hands-on coding. The coding part spans building and training the model, and visualizing the results with matplotlib to see how each feature actually affects the label. When modeling, features that are not linearly related to the target can be dropped first, and the model then built and trained on the remaining ones.

Appendix:

Dataset

,TV,Radio,Newspaper,Sales
1,230.1,37.8,69.2,22.1
2,44.5,39.3,45.1,10.4
3,17.2,45.9,69.3,9.3
4,151.5,41.3,58.5,18.5
5,180.8,10.8,58.4,12.9
6,8.7,48.9,75,7.2
7,57.5,32.8,23.5,11.8
8,120.2,19.6,11.6,13.2
9,8.6,2.1,1,4.8
10,199.8,2.6,21.2,10.6
11,66.1,5.8,24.2,8.6
12,214.7,24,4,17.4
13,23.8,35.1,65.9,9.2
14,97.5,7.6,7.2,9.7
15,204.1,32.9,46,19
16,195.4,47.7,52.9,22.4
17,67.8,36.6,114,12.5
18,281.4,39.6,55.8,24.4
19,69.2,20.5,18.3,11.3
20,147.3,23.9,19.1,14.6
21,218.4,27.7,53.4,18
22,237.4,5.1,23.5,12.5
23,13.2,15.9,49.6,5.6
24,228.3,16.9,26.2,15.5
25,62.3,12.6,18.3,9.7
26,262.9,3.5,19.5,12
27,142.9,29.3,12.6,15
28,240.1,16.7,22.9,15.9
29,248.8,27.1,22.9,18.9
30,70.6,16,40.8,10.5
31,292.9,28.3,43.2,21.4
32,112.9,17.4,38.6,11.9
33,97.2,1.5,30,9.6
34,265.6,20,0.3,17.4
35,95.7,1.4,7.4,9.5
36,290.7,4.1,8.5,12.8
37,266.9,43.8,5,25.4
38,74.7,49.4,45.7,14.7
39,43.1,26.7,35.1,10.1
40,228,37.7,32,21.5
41,202.5,22.3,31.6,16.6
42,177,33.4,38.7,17.1
43,293.6,27.7,1.8,20.7
44,206.9,8.4,26.4,12.9
45,25.1,25.7,43.3,8.5
46,175.1,22.5,31.5,14.9
47,89.7,9.9,35.7,10.6
48,239.9,41.5,18.5,23.2
49,227.2,15.8,49.9,14.8
50,66.9,11.7,36.8,9.7
51,199.8,3.1,34.6,11.4
52,100.4,9.6,3.6,10.7
53,216.4,41.7,39.6,22.6
54,182.6,46.2,58.7,21.2
55,262.7,28.8,15.9,20.2
56,198.9,49.4,60,23.7
57,7.3,28.1,41.4,5.5
58,136.2,19.2,16.6,13.2
59,210.8,49.6,37.7,23.8
60,210.7,29.5,9.3,18.4
61,53.5,2,21.4,8.1
62,261.3,42.7,54.7,24.2
63,239.3,15.5,27.3,15.7
64,102.7,29.6,8.4,14
65,131.1,42.8,28.9,18
66,69,9.3,0.9,9.3
67,31.5,24.6,2.2,9.5
68,139.3,14.5,10.2,13.4
69,237.4,27.5,11,18.9
70,216.8,43.9,27.2,22.3
71,199.1,30.6,38.7,18.3
72,109.8,14.3,31.7,12.4
73,26.8,33,19.3,8.8
74,129.4,5.7,31.3,11
75,213.4,24.6,13.1,17
76,16.9,43.7,89.4,8.7
77,27.5,1.6,20.7,6.9
78,120.5,28.5,14.2,14.2
79,5.4,29.9,9.4,5.3
80,116,7.7,23.1,11
81,76.4,26.7,22.3,11.8
82,239.8,4.1,36.9,12.3
83,75.3,20.3,32.5,11.3
84,68.4,44.5,35.6,13.6
85,213.5,43,33.8,21.7
86,193.2,18.4,65.7,15.2
87,76.3,27.5,16,12
88,110.7,40.6,63.2,16
89,88.3,25.5,73.4,12.9
90,109.8,47.8,51.4,16.7
91,134.3,4.9,9.3,11.2
92,28.6,1.5,33,7.3
93,217.7,33.5,59,19.4
94,250.9,36.5,72.3,22.2
95,107.4,14,10.9,11.5
96,163.3,31.6,52.9,16.9
97,197.6,3.5,5.9,11.7
98,184.9,21,22,15.5
99,289.7,42.3,51.2,25.4
100,135.2,41.7,45.9,17.2
101,222.4,4.3,49.8,11.7
102,296.4,36.3,100.9,23.8
103,280.2,10.1,21.4,14.8
104,187.9,17.2,17.9,14.7
105,238.2,34.3,5.3,20.7
106,137.9,46.4,59,19.2
107,25,11,29.7,7.2
108,90.4,0.3,23.2,8.7
109,13.1,0.4,25.6,5.3
110,255.4,26.9,5.5,19.8
111,225.8,8.2,56.5,13.4
112,241.7,38,23.2,21.8
113,175.7,15.4,2.4,14.1
114,209.6,20.6,10.7,15.9
115,78.2,46.8,34.5,14.6
116,75.1,35,52.7,12.6
117,139.2,14.3,25.6,12.2
118,76.4,0.8,14.8,9.4
119,125.7,36.9,79.2,15.9
120,19.4,16,22.3,6.6
121,141.3,26.8,46.2,15.5
122,18.8,21.7,50.4,7
123,224,2.4,15.6,11.6
124,123.1,34.6,12.4,15.2
125,229.5,32.3,74.2,19.7
126,87.2,11.8,25.9,10.6
127,7.8,38.9,50.6,6.6
128,80.2,0,9.2,8.8
129,220.3,49,3.2,24.7
130,59.6,12,43.1,9.7
131,0.7,39.6,8.7,1.6
132,265.2,2.9,43,12.7
133,8.4,27.2,2.1,5.7
134,219.8,33.5,45.1,19.6
135,36.9,38.6,65.6,10.8
136,48.3,47,8.5,11.6
137,25.6,39,9.3,9.5
138,273.7,28.9,59.7,20.8
139,43,25.9,20.5,9.6
140,184.9,43.9,1.7,20.7
141,73.4,17,12.9,10.9
142,193.7,35.4,75.6,19.2
143,220.5,33.2,37.9,20.1
144,104.6,5.7,34.4,10.4
145,96.2,14.8,38.9,11.4
146,140.3,1.9,9,10.3
147,240.1,7.3,8.7,13.2
148,243.2,49,44.3,25.4
149,38,40.3,11.9,10.9
150,44.7,25.8,20.6,10.1
151,280.7,13.9,37,16.1
152,121,8.4,48.7,11.6
153,197.6,23.3,14.2,16.6
154,171.3,39.7,37.7,19
155,187.8,21.1,9.5,15.6
156,4.1,11.6,5.7,3.2
157,93.9,43.5,50.5,15.3
158,149.8,1.3,24.3,10.1
159,11.7,36.9,45.2,7.3
160,131.7,18.4,34.6,12.9
161,172.5,18.1,30.7,14.4
162,85.7,35.8,49.3,13.3
163,188.4,18.1,25.6,14.9
164,163.5,36.8,7.4,18
165,117.2,14.7,5.4,11.9
166,234.5,3.4,84.8,11.9
167,17.9,37.6,21.6,8
168,206.8,5.2,19.4,12.2
169,215.4,23.6,57.6,17.1
170,284.3,10.6,6.4,15
171,50,11.6,18.4,8.4
172,164.5,20.9,47.4,14.5
173,19.6,20.1,17,7.6
174,168.4,7.1,12.8,11.7
175,222.4,3.4,13.1,11.5
176,276.9,48.9,41.8,27
177,248.4,30.2,20.3,20.2
178,170.2,7.8,35.2,11.7
179,276.7,2.3,23.7,11.8
180,165.6,10,17.6,12.6
181,156.6,2.6,8.3,10.5
182,218.5,5.4,27.4,12.2
183,56.2,5.7,29.7,8.7
184,287.6,43,71.8,26.2
185,253.8,21.3,30,17.6
186,205,45.1,19.6,22.6
187,139.5,2.1,26.6,10.3
188,191.1,28.7,18.2,17.3
189,286,13.9,3.7,15.9
190,18.7,12.1,23.4,6.7
191,39.5,41.1,5.8,10.8
192,75.5,10.8,6,9.9
193,17.2,4.1,31.6,5.9
194,166.8,42,3.6,19.6
195,149.7,35.6,6,17.3
196,38.2,3.7,13.8,7.6
197,94.2,4.9,8.1,9.7
198,177,9.3,6.4,12.8
199,283.6,42,66.2,25.5
200,232.1,8.6,8.7,13.4

Reposted from blog.csdn.net/Twinkle_sone/article/details/108653730