机器学习sklearn(1)简单线性回归

本次练习主要涉及的知识点有：数据读取、绘制散点图、数据预处理、模型调用和模型评估，流程如下：

引入库

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import cross_val_score    #交叉验证
%matplotlib inline   

# Load the tab-separated data file. There is no header row, so pandas
# labels the columns 0, 1, 2.
data = pd.read_csv('ex0.txt', header=None, sep='\t')
first_rows = data.head()
print(first_rows)

   0         1         2
0  1.0  0.067732  3.176513
1  1.0  0.427810  3.816464
2  1.0  0.995731  4.550095
3  1.0  0.738336  4.256571
4  1.0  0.981083  4.560815

# Scatter plot of the raw data: column 1 (x) against column 2 (y).
# alpha < 1 makes markers semi-transparent, so overlapping points render
# darker and show where the data is dense. s sets the marker size.
plt.scatter(data[1], data[2], s=20, c='blue', alpha=.5)
plt.title('Dataset')
plt.xlabel('X')
plt.ylabel('Y')  # label the y axis as well, matching the x axis label
plt.show()

regr = LinearRegression()

# Features are columns 0 and 1 as a NumPy array; the target is column 2.
# In the head() preview, column 0 appears to be a constant 1.0. With the
# default fit_intercept=True that column is redundant with the intercept,
# which is why its coefficient prints as 0.
data_data = data[[0, 1]].values
data_target = data[2].values

# Hold out the last n_test rows as the test set. There is no shuffling:
# the split follows file order. The size is named once so the four
# slices below stay in sync. It must be >= 1, because [:-0] would select
# nothing for training.
n_test = 10
data_x_train = data_data[:-n_test]
data_x_test = data_data[-n_test:]
data_y_train = data_target[:-n_test]
data_y_test = data_target[-n_test:]
regr.fit(data_x_train, data_y_train)

LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
         normalize=False)

# Fitted weight for each feature column (columns 0 and 1).
fitted_coefs = regr.coef_
print("Coefficients: \n", fitted_coefs)

Coefficients: 
 [0.         1.69857344]

# Fitted intercept (bias) term of the linear model.
fitted_intercept = regr.intercept_
print("Intercept: \n", fitted_intercept)

Intercept: 
 3.0045212594826105

# Predict target values for the held-out test rows.
data_y_pred = regr.predict(X=data_x_test)
print(data_y_pred)

[4.48233111 4.03174204 3.43801758 4.24540729 3.84332608 3.44108351
 4.42068139 3.12358276 3.89978836 3.20183265]

# Mean squared error of the test-set predictions.
# r2_score is also imported here because the next cell uses it.
from sklearn.metrics import mean_squared_error, r2_score
test_mse = mean_squared_error(data_y_test, data_y_pred)
print("Mean square error: %.2f" % test_mse)

Mean square error: 0.01

# Coefficient of determination (R^2) of the test-set predictions.
# 1.0 means a perfect fit. This is not an analysis of variance (ANOVA).
test_r2 = r2_score(data_y_test, data_y_pred)
print("Variance score: %.2f" % test_r2)

Variance score: 0.97

# Evaluate on the test rows using the estimator's own score method,
# which returns R^2 for LinearRegression.
test_score = regr.score(data_x_test, data_y_test)
test_score

0.9666641890737667

# Overlay the fitted regression line on the full dataset.
# Sort by x first: plt.plot connects points in the order given, so
# unsorted x would draw the line back and forth between points.
x_values = data[1].values
order = x_values.argsort()
plt.plot(x_values[order], regr.predict(data_data[order]), c='blue', linewidth=3)
plt.scatter(x_values, data_target, c='black')
plt.show()

# 10-fold cross-validation over the whole dataset. This prints one score
# per fold; with no scoring argument, the estimator's own score (R^2) is used.
scores = cross_val_score(estimator=regr, X=data_data, y=data_target, cv=10)
print(scores)

[0.97221449 0.97779641 0.93482591 0.97954816 0.97779962 0.95599549
 0.96683191 0.97480362 0.98098483 0.97303444]

# Average cross-validation score across the folds.
mean_cv_score = scores.mean()
mean_cv_score

0.9693834864847795

猜你喜欢

转载自blog.csdn.net/weixin_44530236/article/details/88560524