sklearn achieve linear regression model to predict prices

Questions asked

Rate prediction model established: the use ex1data1.txt( single feature ) and ex1data2.txt( multi-feature data) is, linear regression and prediction.

Known as scatter plots, data roughly in line with a linear relationship, it is temporarily study other forms of return.

Two data at the end.

Single characteristic linear regression

ex1data1.txtWherein the data is single, as can be a simple linear regression: \ (Y = AX + B \) .

Divided according to whether data to produce two schemes: Scheme I, all samples are used for training and prediction; Scheme II, the training portion of the sample to a portion of the model used to test.

Option One

Data ex1data1.txt the linear regression, all samples are used for training and forecasting.

Code is implemented as follows:

"""
    对ex1data1.txt中的数据进行线性回归,所有样本都用来训练和预测
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

# 数据格式:城市人口,食品经销商利润

# 读取数据
data = np.loadtxt('ex1data1.txt', delimiter=',')
data_X = data[:, 0]
data_y = data[:, 1]

# 训练模型
model = LinearRegression()
model.fit(data_X.reshape([-1, 1]), data_y)

# 利用模型进行预测
y_predict = model.predict(data_X.reshape([-1, 1]))

# 结果可视化
plt.scatter(data_X, data_y, color='red')
plt.plot(data_X, y_predict, color='blue', linewidth=3)
plt.xlabel('城市人口')
plt.ylabel('食品经销商利润')
plt.title('线性回归——城市人口与食品经销商利润的关系')
plt.show()

# 模型参数
print(model.coef_)
print(model.intercept_)
# MSE
print(mean_squared_error(data_y, y_predict))
# R^2
print(r2_score(data_y, y_predict))

The results are as follows:

Seen from the form of the function and \ (R ^ 2 \) 0.70

[1.19303364]
-3.89578087831185
8.953942751950358
0.7020315537841397

ex1data1_1.png

Option II

Ex1data1.txt data of linear regression, to the training portion of the sample, some of the samples to predict.

To achieve the following:

"""
    对ex1data1.txt中的数据进行线性回归,部分样本用来训练,部分样本用来预测
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

# 数据格式:城市人口,食品经销商利润

# 读取数据
data = np.loadtxt('ex1data1.txt', delimiter=',')
data_X = data[:, 0]
data_y = data[:, 1]

# 数据分割
X_train, X_test, y_train, y_test = train_test_split(data_X, data_y)

# 训练模型
model = LinearRegression()
model.fit(X_train.reshape([-1, 1]), y_train)

# 利用模型进行预测
y_predict = model.predict(X_test.reshape([-1, 1]))

# 结果可视化
plt.scatter(X_test, y_test, color='red')  # 测试样本
plt.plot(X_test, y_predict, color='blue', linewidth=3)
plt.xlabel('城市人口')
plt.ylabel('食品经销商利润')
plt.title('线性回归——城市人口与食品经销商利润的关系')
plt.show()

# 模型参数
print(model.coef_)
print(model.intercept_)
# MSE
print(mean_squared_error(y_test, y_predict))
# R^2
print(r2_score(y_test, y_predict))

The results are as follows :

Seen from the form of the function and \ (R ^ 2 \) 0.80

[1.21063939]
-4.195481965945055
5.994362667047617
0.8095125123727652

ex1data1_2.png

Multi-feature linear regression

ex1data2.txtSecond feature data is a, for a simple polyhydric (dihydric here) to linear regression: \ (Y + B = a_1x_1 a_2x_2 + \) .

Data ex1data2.txt the linear regression, all samples are used for training and forecasting.

Code is implemented as follows:

"""
    对ex1data2.txt中的数据进行线性回归,所有样本都用来训练和预测
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from mpl_toolkits.mplot3d import Axes3D  # 不要去掉这个import
from sklearn.metrics import mean_squared_error, r2_score
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

# 数据格式:城市人口,房间数目,房价

# 读取数据
data = np.loadtxt('ex1data2.txt', delimiter=',')
data_X = data[:, 0:2]
data_y = data[:, 2]

# 训练模型
model = LinearRegression()
model.fit(data_X, data_y)

# 利用模型进行预测
y_predict = model.predict(data_X)

# 结果可视化
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.scatter(data_X[:, 0], data_X[:, 1], data_y, color='red')
ax.plot(data_X[:, 0], data_X[:, 1], y_predict, color='blue')
ax.set_xlabel('城市人口')
ax.set_ylabel('房间数目')
ax.set_zlabel('房价')
plt.title('线性回归——城市人口、房间数目与房价的关系')
plt.show()

# 模型参数
print(model.coef_)
print(model.intercept_)
# MSE
print(mean_squared_error(data_y, y_predict))
# R^2
print(r2_score(data_y, y_predict))

The results are as follows:

Seen from the form of the function and \ (R ^ 2 \) 0.73

[  139.21067402 -8738.01911233]
89597.90954279748
4086560101.205658
0.7329450180289141

ex1data2.png

Two data

ex1data1.txt

6.1101,17.592
5.5277,9.1302
8.5186,13.662
7.0032,11.854
5.8598,6.8233
8.3829,11.886
7.4764,4.3483
8.5781,12
6.4862,6.5987
5.0546,3.8166
5.7107,3.2522
14.164,15.505
5.734,3.1551
8.4084,7.2258
5.6407,0.71618
5.3794,3.5129
6.3654,5.3048
5.1301,0.56077
6.4296,3.6518
7.0708,5.3893
6.1891,3.1386
20.27,21.767
5.4901,4.263
6.3261,5.1875
5.5649,3.0825
18.945,22.638
12.828,13.501
10.957,7.0467
13.176,14.692
22.203,24.147
5.2524,-1.22
6.5894,5.9966
9.2482,12.134
5.8918,1.8495
8.2111,6.5426
7.9334,4.5623
8.0959,4.1164
5.6063,3.3928
12.836,10.117
6.3534,5.4974
5.4069,0.55657
6.8825,3.9115
11.708,5.3854
5.7737,2.4406
7.8247,6.7318
7.0931,1.0463
5.0702,5.1337
5.8014,1.844
11.7,8.0043
5.5416,1.0179
7.5402,6.7504
5.3077,1.8396
7.4239,4.2885
7.6031,4.9981
6.3328,1.4233
6.3589,-1.4211
6.2742,2.4756
5.6397,4.6042
9.3102,3.9624
9.4536,5.4141
8.8254,5.1694
5.1793,-0.74279
21.279,17.929
14.908,12.054
18.959,17.054
7.2182,4.8852
8.2951,5.7442
10.236,7.7754
5.4994,1.0173
20.341,20.992
10.136,6.6799
7.3345,4.0259
6.0062,1.2784
7.2259,3.3411
5.0269,-2.6807
6.5479,0.29678
7.5386,3.8845
5.0365,5.7014
10.274,6.7526
5.1077,2.0576
5.7292,0.47953
5.1884,0.20421
6.3557,0.67861
9.7687,7.5435
6.5159,5.3436
8.5172,4.2415
9.1802,6.7981
6.002,0.92695
5.5204,0.152
5.0594,2.8214
5.7077,1.8451
7.6366,4.2959
5.8707,7.2029
5.3054,1.9869
8.2934,0.14454
13.394,9.0551
5.4369,0.61705

ex1data2.txt

2104,3,399900
1600,3,329900
2400,3,369000
1416,2,232000
3000,4,539900
1985,4,299900
1534,3,314900
1427,3,198999
1380,3,212000
1494,3,242500
1940,4,239999
2000,3,347000
1890,3,329999
4478,5,699900
1268,3,259900
2300,4,449900
1320,2,299900
1236,3,199900
2609,4,499998
3031,4,599000
1767,3,252900
1888,2,255000
1604,3,242900
1962,4,259900
3890,3,573900
1100,3,249900
1458,3,464500
2526,3,469000
2200,3,475000
2637,3,299900
1839,2,349900
1000,1,169900
2040,4,314900
3137,3,579900
1811,4,285900
1437,3,249900
1239,3,229900
2132,4,345000
4215,4,549000
2162,4,287000
1664,2,368500
2238,3,329900
2567,4,314000
1200,3,299000
852,2,179900
1852,4,299900
1203,3,239500

Author: @ smelly salted fish

Please indicate the source: https://www.cnblogs.com/chouxianyu/

Welcome to discuss and share!


Guess you like

Origin www.cnblogs.com/chouxianyu/p/11704665.html