diabetes 一元线性回归

导入相应的模块

import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score

下载数据集

diabetes = datasets.load_diabetes()
diabetes.keys()#对数据集进行查看的方法
dict_keys(['data', 'target', 'DESCR', 'feature_names', 'data_filename', 'target_filename'])
diabetes.feature_names#显示数据集行标签
diabetes.data_filename#数据集位置
['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6']

'F:\\Anaconda\\Anaconda\\lib\\site-packages\\sklearn\\datasets\\data\\diabetes_data.csv.gz'
diabetes_x = diabetes.data[:,np.newaxis,2]#取第三列的值,

切片

不清楚np.newaxis时,做了以下试验,详解见

https://docs.scipy.org/doc/numpy/reference/constants.html#numpy.newaxis

a = np.arange(10,19).reshape((3,3))#不知道多少列的情况下将切片后的元素列排
a
array([[10, 11, 12],
       [13, 14, 15],
       [16, 17, 18]])
a[:,2].reshape(3,1)
array([[12],
       [15],
       [18]])
a[:,np.newaxis,2]
array([[12],
       [15],
       [18]])
diabetes_y = diabetes.target

一维切片

b = [1,2,3,4,5,6,7]
b[-2:]#倒数第二个元素之后的所有元素
[6, 7]
b[:-2]#倒数第二个元素之前的所有元素
[1, 2, 3, 4, 5]

多维切片

a
array([[10, 11, 12],
       [13, 14, 15],
       [16, 17, 18]])
a[:2,:2]#逗号前面是行,后面是列,列行元素取的方法和一维相同
array([[10, 11],
       [13, 14]])
a[[0,2],:2]#当跨行取值时,要加中括号
array([[10, 11],
       [16, 17]])
a[[0,2],0:2]
array([[10, 11],
       [16, 17]])

划分训练集和测试集

from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test=train_test_split(diabetes_x,diabetes_y,test_size=0.2)

regr = linear_model.LinearRegression() #使用线性回归模型

regr.fit(x_train,y_train)#fit()函数拟合
LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None,
         normalize=False)
diabetes_y_pred = regr.predict(x_test)

print('cofficients:\n', regr.coef_)#输出回归系数
cofficients:
 [980.59601535]
print("mean squared error: %.2f" %mean_squared_error(y_test,diabetes_y_pred))
mean squared error: 4191.83
print("variance score : %.2f" % r2_score(y_test,diabetes_y_pred))
variance score : 0.27
plt.scatter(x_test,y_test,color = 'black')
plt.plot(x_test,diabetes_y_pred,color = 'blue',linewidth=3)
[<matplotlib.lines.Line2D at 0x2365243c080>]

在这里插入图片描述



```python

猜你喜欢

转载自blog.csdn.net/weixin_43332500/article/details/89068216
今日推荐