线性回归预测

pandas库

Python Data Analysis Library 或 pandas 是基于NumPy 的一种工具,该工具是为了解决数据分析任务而创建的。Pandas 纳入了大量库和一些标准的数据模型,提供了高效地操作大型数据集所需的工具。pandas提供了大量能使我们快速便捷地处理数据的函数和方法。你很快就会发现,它是使Python成为强大而高效的数据分析环境的重要因素之一。
这个库主要是用于对数据的清洗与提取。

matplotlib库

画图库

sklearn

机器学习库

numpy

数值计算库

代码示例

import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from pprint import pprint
import numpy as np

path = 'Advertising.csv'

#pandas读入
data = pd.read_csv(path,header=0)
# x = data[['TV','Radio','Newspaper']]
x = data[['TV','Radio']]
y = data['Sales']
print(x)
print(y)

#将所有的特征数据画在一起
plt.figure(facecolor = 'w')#创建一个图形,背景色为白色
plt.plot(data['TV'],y,'ro',label='TV')
plt.plot(data['Radio'],y,'g^',label='Radio')
plt.plot(data['Newspaper'],y,'mv',label='Newspaper')
plt.legend(loc = 'lower right')
plt.xlabel(u'广告费',fontsize = 16)
plt.ylabel(u'销量额',fontsize = 16)
plt.title(u'广告花费与销量额对比',fontsize = 16)
plt.grid()
plt.show()
#单独画出每种的特征数据图
plt.figure(facecolor='w',figsize=(9,10))
plt.subplot(311)
plt.plot(data['TV'],y,'ro')
plt.title('TV')
plt.ylabel('Sales')
plt.grid()

plt.subplot(312)
plt.plot(data['Radio'],y,'g^')
plt.title('Radio')
plt.ylabel('Sales')
plt.grid()

plt.subplot(313)
plt.plot(data['Newspaper'],y,'mv')
plt.title('Newspaper')
plt.ylabel('Sales')
plt.grid()

plt.show()

x_train,x_test,y_train,y_test = train_test_split(x,y,train_size = 0.8,random_state = 1)
print(x_test)
print(x_train.shape,y_train.shape)
model = LinearRegression()
model.fit(x_train, y_train)
print(model)
print(model.coef_,model.intercept_)

order = y_test.argsort(axis=0)
y_test = y_test.values[order]
x_test = x_test.values[order,:]
y_hat = model.predict(x_test)
mse = np.average((y_hat - np.array(y_test))**2) 
rmse = np.sqrt(mse)
print('MSE = ',mse)
print('RMSE = ',rmse)
print('R2 = ',model.score(x_train, y_train))
print('R2 = ',model.score(x_test, y_test))

plt.figure(facecolor='w')
t = np.arange(len(x_test))
plt.plot(t,y_test,'r-',linewidth=2,label='real data')
plt.plot(t,y_hat,'g-',linewidth=2,label='predict data')
plt.legend(loc='upper right')
plt.title('predict sales of lineregression',fontsize=18)
plt.grid(b=True)
plt.show()

猜你喜欢

转载自blog.csdn.net/weixin_41811413/article/details/82830375