机器学习-交叉验证和网格搜索(带案例)

八、交叉验证和网格搜索

1.什么是交叉验证?

就是将拿到的训练数据,分成训练集和验证集,比如将一份数据分成4份,其中一份作为验证集。然后经过过4次测试,每次都更换不同的验证集。即得到4次模型的结果,取平均值作为最终结果。又称4折交叉验证。

2.为什么要做交叉验证?

交叉验证的目的:为了让被评估的模型更加准确可信。

问题:这个只是让被评估的模型更加准确可信,那么怎么选择或者调优参数呢?

3.什么是网格搜索

通常情况下,**有很多参数是需要手动指定的,这种叫超参数。**但是手动过程繁杂,所以需要对模型预设几种参数组合。每组超参数都采用交叉验证来评估。最后选出最优参数组合建立模型。

4.交叉验证(模型准确可信),网格搜素(模型调优)API:
  • sklearn.model_selection.GridSearchCV(estimator,param_grid=None,cv=None)
    • 对估计器的指定参数值进行详尽搜索
    • estimator:估计器对象
    • param_grid:估计器参数(dict) {‘n_neighbors’:[1,3,5]}
    • cv:指定几折交叉验证
    • fit:输入训练数据
    • score:准确率
    • 结果分析
      • bestscore_:在交叉验证中验证的最好结果
      • bestestimator:最好的参数模型
      • cvresults:每次交叉验证后的验证集准确率结果和训练集准确率结果
5.鸢尾花案例增加K值调优
# 1.获取数据集
from sklearn.datasets import load_iris()
iris = load_iris()

# 2.数据基本处理--划分数据集
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=1)

# 3.特征工程--标准化
from sklearn.preprocessing import StandardScale
transfer = StandardScale()
x_train = transfer.fit_transform(x_train)
x_test = transfer.transform(x_test)

# 4.KNN估计器
from sklearn.neighbors import KNeighborsClassifier
estimator = KNeighborsClassifier()
# 交叉验证和网络搜索
from sklearn.model_selection import GridSearchCV
param_dict = {'n_neighbors':[1,3,5]}
estimator = GridSerachCV(estimator,param_grid=param_dict,cv=3)
estimator.fit(x_train,y_train)

# 5.模型评估
# 方案一,对比真实值和预测值
y_predict = estimator.predict(x_test)
print("预测值为:",y_predict)
print("真实值和预测值的对比:",y_predict==y_test)
# 方案二,直接计算准确率
score = estimator.score(x_test,y_test)
print("准确率为:",score)

# 6.查看交叉验证和网络搜索的结果
print("交叉验证中验证的最好结果",estimator.bestscore_)
print("最好的参数模型",estimator.best_estimator_)
print("每次验证后的准确率结果",estimator.cv_results_)

九、案例3:预测facebook将要签到的位置

1.项目描述

本次比赛的目的是预测一个人将要签到的地方。 为了本次比赛,Facebook创建了一个虚拟世界,其中包括10公里*10公里共100平方公里的约10万个地方。 对于给定的坐标集,您的任务将根据用户的位置,准确性和时间戳等预测用户下一次的签到位置。 数据被制作成类似于来自移动设备的位置数据。 请注意:您只能使用提供的数据进行预测。

2.数据集介绍
文件说明 train.csv, test.csv
  row id:签入事件的id
  x y:坐标
  accuracy: 准确度,定位精度
  time: 时间戳
  place_id: 签到的位置,这也是你需要预测的内容
3.步骤分析
  • 对于数据做一些基本处理(这里所做的一些处理不一定达到很好的效果,我们只是简单尝试,有些特征我们可以根据一些特征选择的方式去做处理)

    • 1 缩小数据集范围 DataFrame.query()
    • 2 选取有用的时间特征
    • 3 将签到位置少于n个用户的删除
  • 分割数据集

  • 标准化处理

  • k-近邻预测

发布了104 篇原创文章 · 获赞 33 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/WangTaoTao_/article/details/103065002