版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/zhaohaibo_/article/details/81903389
今天的第一个任务是拟合三个点。
三个点。。
复习了一下最小二乘法(线性拟合)
- 损失函数:残差平方和
- 评价指标:MSE(均方误差)、R2(方差得分)
此示例使用糖尿病数据集的一特征,计算了系数、残差平方和与方差得分。展示了线性回归方法的二维绘图。
图中显示了线性回归拟合出的一条直线,它将使数据集中观测到的响应之间的残差平方和和线性近似预测的响应最佳化。
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score
class Linear(object):
def main(self):
# 使用diabetes dataset
diabetes = datasets.load_diabetes()
# 使用数据集的第二列
# 如果是data[:2],将生成Size为(3,)的矩阵
# np.newaxis可以为矩阵(多维数组)增加一个轴,生成Size为(3,1)的矩阵
diabetes_X = diabetes.data[:, np.newaxis, 2]
# 后20条数据 为训练集
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]
# 标定Label
diabetes_y_train = diabetes.target[:-20]
diabetes_y_test = diabetes.target[-20:]
# 线性回归模型
regr = linear_model.LinearRegression()
# 训练
regr.fit(diabetes_X_train, diabetes_y_train)
# 预测
diabetes_y_pred = regr.predict(diabetes_X_test)
# 系数
print('Coefficients: \n', regr.coef_)
# 均方误差
print("Mean squared error: %.2f"
% mean_squared_error(diabetes_y_test, diabetes_y_pred))
# 方差得分:1是完美预测
print('Variance score: %.2f' % r2_score(diabetes_y_test, diabetes_y_pred))
# 画图
plt.scatter(diabetes_X_test, diabetes_y_test, color='black')
plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)
plt.savefig("linear.jpg")
# plt.xticks(())
# plt.yticks(())
plt.show()
if __name__ == '__main__':
lin = Linear()
lin.main()
好了我可以拟合我的三个点了。
# 线性拟合
from sklearn import linear_model
import numpy as np
reg = linear_model.LinearRegression()
X = np.array([14.36, 25.6, 50]).reshape(-1,1)
X_pre = np.array([3,30,55]).reshape(-1,1)
y = np.array([616, 1200, 2400]).reshape(-1,1)
reg.fit (X, y, sample_weight=None)
y_pre = reg.predict(X_pre)
# 绘图
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontManager, FontProperties
font = FontProperties(fname='/System/Library/Fonts/STHeiti Medium.ttc')
sns.set_style("whitegrid")
plt.figure(figsize=(8,4))
plt.scatter(X, y, color='black')
plt.plot(X_pre, y_pre, color='#B5D6FB', linewidth=3)
plt.axis([0,60,0,3000]) # 前两个数是x的范围,后两个y的范围
plt.xlabel("补贴(亿元)", FontProperties=font)
plt.ylabel("试点面积(万亩)", FontProperties=font)
#plt.title("", FontProperties=font)
plt.savefig('linear.jpg')