Scikit-learn 学习心得(线性回归)

一开始学习线性回归的时候,是在MATLAB里写的,现在学习用python来做。

Scikit-learn 的线性回归怎么用呢?

开始参考的是这个:https://blog.csdn.net/u010900574/article/details/52666291

官方参考为:http://scikit-learn.org/stable/modules/linear_model.html

其实很简单,主要有如下几步:

1,各种import

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression

2,构造样本

#这里尝试做一个直线10 + 2x的回归

NUM = 10
x = np.linspace(1,10,NUM)
y = 10 + 2*x + np.random.randn(NUM)
x_train = np.array([[tmp] for tmp in x]) #sklearn线性回归函数的入参要整成这种,不能用前两行的x,(这个在一开始把我整的晕乎乎的,还不太熟悉python的数据类型)
y_train = np.array([[tmp] for tmp in y])

3,训练

model_linear = LinearRegression()  #设置线性回归模块
model_linear.fit(x_train, y_train) #训练数据,得出参数

4,得到特征参数:

print(model_linear.coef_)        #coef_ 用于存放系数
print(model_linear.intercept_)   #intercept_ 用于存放截距 (偏置项)

其他参数请查阅相关资料,例如:https://www.cnblogs.com/magle/p/5881170.html

5,模型应用:

y_predict = model_linear.predict(x_train)

6,绘图

label = ['sample', 'predict']
plt.plot(x_train, y_train, 'o')
plt.plot(x_train, y_predict)
plt.legend(label)
plt.show()

完成。

猜你喜欢

转载自blog.csdn.net/a274767172/article/details/82561744
今日推荐