参考 https://blog.csdn.net/u010900574/article/details/52666291
这篇文章写得非常好，我从中吸取了经验，也来动手实践一下机器学习中的回归算法。
涉及基本回归方法（线性回归、决策树、SVM、KNN）和集成方法（随机森林、AdaBoost 和 GBRT）。
下面是源码，部分参考了原文作者的代码：
import numpy as np
import matplotlib.pyplot as plt
def _make_split(n_samples):
    """Build one (n_samples, 3) array of synthetic regression data.

    Columns are x1, x2 and the target y, where
    y = 10*sin(x1) + 10*cos(x2) + x1 + x2 + noise + 0.1
    and the noise is drawn independently for every row from
    np.random.randn (so it honours np.random.seed).
    """
    x1 = np.linspace(0, 100, n_samples).reshape(-1, 1)
    x2 = np.linspace(0, 50, n_samples).reshape(-1, 1)
    # One noise value per sample. A single randn(1) draw would add the same
    # scalar to every row, which is a constant offset rather than noise.
    noise = np.random.randn(n_samples, 1)
    y = 10 * np.sin(x1) + 10 * np.cos(x2) + x1 + x2 + noise + 0.1
    return np.concatenate((x1, x2, y), axis=1)


def get_data(n_samples=50):
    """Return (data_train, data_test), two (n_samples, 3) arrays [x1, x2, y].

    Args:
        n_samples: number of rows in each split (default 50).

    The training split is generated first and the test split second, each
    with its own independent noise draws.
    NOTE(review): both splits use the same x1/x2 grid, so the test score
    measures fit to fresh noise at already-seen inputs rather than true
    generalisation; consider jittering the test inputs.
    """
    data_train = _make_split(n_samples)
    data_test = _make_split(n_samples)
    return data_train, data_test
train, test = get_data()
# Each array's columns are [x1, x2, y]: first two are features, last is target.
x_train = train[:, :2]
y_train = train[:, 2]
x_test = test[:, :2]
y_test = test[:, 2]
def try_different_method(clf):
    """Fit *clf* on the training split and plot its test-set predictions.

    Reads the module-level x_train/y_train/x_test/y_test arrays. Draws the
    true test targets (red) and the predictions (green) against sample
    index, titled with the value of clf.score(x_test, y_test), then shows
    the figure.
    """
    clf.fit(x_train, y_train)
    test_score = clf.score(x_test, y_test)
    predicted = clf.predict(x_test)
    positions = np.arange(len(predicted))

    plt.figure(figsize=(10, 8))
    plt.plot(positions, y_test, 'ro-', label='true value')
    plt.plot(positions, predicted, 'go-', label='predict value')
    plt.title('score: {:f}'.format(test_score))
    plt.legend()
    plt.show()
# Linear regression.
from sklearn.linear_model import LinearRegression
line = LinearRegression()
try_different_method(line)

# Ridge regression with regularisation strength alpha=0.001.
from sklearn.linear_model import Ridge
ridge = Ridge(random_state=0, alpha=0.001)
try_different_method(ridge)

# Decision-tree regression.
from sklearn.tree import DecisionTreeRegressor
tree = DecisionTreeRegressor()
try_different_method(tree)

# k-nearest-neighbours regression (library defaults).
from sklearn import neighbors
knn = neighbors.KNeighborsRegressor()
try_different_method(knn)

# All three ensemble regressors below come from sklearn.ensemble; import it
# once rather than before every model.
from sklearn import ensemble

# Random forest using 20 decision trees.
rf = ensemble.RandomForestRegressor(n_estimators=20)
try_different_method(rf)

# AdaBoost with 50 estimators.
ada = ensemble.AdaBoostRegressor(n_estimators=50)
try_different_method(ada)

# Gradient-boosted regression trees (GBRT) with 100 estimators.
gbrt = ensemble.GradientBoostingRegressor(n_estimators=100)
try_different_method(gbrt)
# SVR: support vector regression with an RBF kernel.
from sklearn.svm import SVR
svr = SVR(kernel='rbf', C=1e3, gamma=0.1)
try_different_method(svr)
发现自己还是挺菜的