八、交叉验证和网格搜索
1.什么是交叉验证?
就是将拿到的训练数据,分成训练集和验证集,比如将一份数据分成4份,其中一份作为验证集。然后经过过4次测试,每次都更换不同的验证集。即得到4次模型的结果,取平均值作为最终结果。又称4折交叉验证。
2.为什么要做交叉验证?
交叉验证的目的:为了让被评估的模型更加准确可信。
问题:这个只是让被评估的模型更加准确可信,那么怎么选择或者调优参数呢?
3.什么是网格搜索
通常情况下,**有很多参数是需要手动指定的,这种叫超参数。**但是手动过程繁杂,所以需要对模型预设几种参数组合。每组超参数都采用交叉验证来评估。最后选出最优参数组合建立模型。
4.交叉验证(模型准确可信),网格搜素(模型调优)API:
- sklearn.model_selection.GridSearchCV(estimator,param_grid=None,cv=None)
- 对估计器的指定参数值进行详尽搜索
- estimator:估计器对象
- param_grid:估计器参数(dict) {‘n_neighbors’:[1,3,5]}
- cv:指定几折交叉验证
- fit:输入训练数据
- score:准确率
- 结果分析
- bestscore_:在交叉验证中验证的最好结果
- bestestimator:最好的参数模型
- cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果
5.鸢尾花案例增加K值调优
# 1.获取数据集
from sklearn.datasets import load_iris()
iris = load_iris()
# 2.数据基本处理--划分数据集
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=1)
# 3.特征工程--标准化
from sklearn.preprocessing import StandardScale
transfer = StandardScale()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)
# 4.KNN估计器
from sklearn.neighbors import KNeighborsClassifier
estimator = KNeighborsClassifier()
# 交叉验证和网络搜索
from sklearn.model_selection import GridSearchCV
param_dict = {'n_neighbors':[1,3,5]}
estimator = GridSerachCV(estimator,param_grid=param_dict,cv=3)
estimator.fit(x_train,y_train)
# 5.模型评估
# 方案一,对比真实值和预测值
y_predict = estimator.predict(x_test)
print("预测值为:",y_predict)
print("真实值和预测值的对比:",y_predict==y_test)
# 方案二,直接计算准确率
score = estimator.score(x_test,y_test)
print("准确率为:",score)
# 6.查看交叉验证和网络搜索的结果
print("交叉验证中验证的最好结果",estimator.bestscore_)
print("最好的参数模型",estimator.best_estimator_)
print("每次验证后的准确率结果",estimator.cv_results_)
九、案例3:预测facebook将要签到的位置
1.项目描述
本次比赛的目的是预测一个人将要签到的地方。 为了本次比赛,Facebook创建了一个虚拟世界,其中包括10公里*10公里共100平方公里的约10万个地方。 对于给定的坐标集,您的任务将根据用户的位置,准确性和时间戳等预测用户下一次的签到位置。 数据被制作成类似于来自移动设备的位置数据。 请注意:您只能使用提供的数据进行预测。
2.数据集介绍
文件说明 train.csv, test.csv
row id:签入事件的id
x y:坐标
accuracy: 准确度,定位精度
time: 时间戳
place_id: 签到的位置,这也是你需要预测的内容
3.步骤分析
-
对于数据做一些基本处理(这里所做的一些处理不一定达到很好的效果,我们只是简单尝试,有些特征我们可以根据一些特征选择的方式去做处理)
- 1 缩小数据集范围 DataFrame.query()
- 2 选取有用的时间特征
- 3 将签到位置少于n个用户的删除
-
分割数据集
-
标准化处理
-
k-近邻预测