Machine Learning Journey: Grid Search with GridSearchCV in Python as a Model Validation Method

git: https://github.com/linyi0604/MachineLearning

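How do you decide which parameters a model should use?

k-fold cross-validation:
Split the samples into k parts (folds).
In each round, use one fold as test data and the remaining k-1 folds as training data.
Repeat for k rounds, so every fold serves as test data exactly once.
Averaging the k scores makes full use of the sample data and gives an estimate of how well the model performs on it.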
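
A minimal pure-Python sketch of the splitting step, assuming the samples are simply numbered 0 to n-1 and kept in their original order (no shuffling). The function name k_fold_indices is made up for illustration. The fold sizes follow the same rule as scikit-learn's KFold (the first n % k folds get one extra sample); note that GridSearchCV with a classifier and an integer cv actually uses stratified folds, which also keep the class proportions balanced in every fold.

```python
def k_fold_indices(n_samples, k):
    """Yield (train_idx, test_idx) for each of the k rounds of k-fold CV.

    Samples stay in their original order (no shuffling). The first
    n_samples % k folds get one extra sample so that every sample lands
    in exactly one test fold.
    """
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    base, extra = divmod(n_samples, k)
    start = 0
    for fold in range(k):
        stop = start + base + (1 if fold < extra else 0)
        test_idx = list(range(start, stop))
        # Every sample outside the current fold is used for training.
        train_idx = list(range(start)) + list(range(stop, n_samples))
        yield train_idx, test_idx
        start = stop


# 10 samples, 3 folds: each sample is tested exactly once.
for train_idx, test_idx in k_fold_indices(10, 3):
    print("test:", test_idx, "train size:", len(train_idx))
# test: [0, 1, 2, 3] train size: 6
# test: [4, 5, 6] train size: 7
# test: [7, 8, 9] train size: 7
```

In each round the model is trained on train_idx and scored on test_idx, and the k scores are then averaged into one number.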


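Grid search:
A brute-force enumeration method.
List several candidate values for each model parameter,
evaluate the model on every combination of those values (for example with k-fold cross-validation),
and keep the combination that scores best as the model's parameters.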
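
The enumeration itself takes only a few lines. A sketch in plain Python that builds the grid with itertools.product; parameter_grid, grid_search and toy_score are hypothetical names, and toy_score is a made-up stand-in for the mean cross-validation accuracy you would normally compute for each combination. The candidate values mirror the ones the article's code generates with np.logspace(-2, 1, 4) and np.logspace(-1, 1, 3).

```python
from itertools import product


def parameter_grid(param_grid):
    """Yield one dict for every combination of the candidate values."""
    names = list(param_grid)
    for values in product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))


def grid_search(param_grid, evaluate):
    """Brute force: score every combination, keep the highest-scoring one."""
    best_params, best_score = None, float("-inf")
    for params in parameter_grid(param_grid):
        score = evaluate(params)
        if score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Same candidate values as in the article's code: 4 gammas x 3 Cs.
grid = {"gamma": [0.01, 0.1, 1.0, 10.0], "C": [0.1, 1.0, 10.0]}
print(len(list(parameter_grid(grid))))  # 12


def toy_score(params):
    """Hypothetical score that is highest at gamma=1.0, C=1.0."""
    return -abs(params["gamma"] - 1.0) - abs(params["C"] - 1.0)


print(grid_search(grid, toy_score))
```

The cost grows multiplicatively: adding a third parameter with 5 candidates would turn 12 combinations into 60, which is why grid search is usually kept to a few coarse, log-spaced values.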


Python implementation:
from sklearn.datasets import fetch_20newsgroups
# sklearn.cross_validation and sklearn.grid_search were removed in scikit-learn 0.20;
# train_test_split and GridSearchCV now live in sklearn.model_selection
from sklearn.model_selection import train_test_split, GridSearchCV
import numpy as np
from sklearn.svm import SVC
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline

# Download the full 20 Newsgroups dataset (needs network access the first time)
news = fetch_20newsgroups(subset="all")
# Split the first 3000 documents into training data and test data
x_train, x_test, y_train, y_test = train_test_split(news.data[:3000],
                                                    news.target[:3000],
                                                    test_size=0.25,
                                                    random_state=33)

# Use a Pipeline to chain TF-IDF feature extraction and the SVC classifier
clf = Pipeline([("vect", TfidfVectorizer(stop_words="english", analyzer="word")), ("svc", SVC())])

# Two hyperparameters to tune: 4 values of svc__gamma and 3 values of svc__C, 12 combinations in total
# np.logspace(start, stop, num) creates num log-spaced values from 10^start to 10^stop
parameters = {"svc__gamma": np.logspace(-2, 1, 4), "svc__C": np.logspace(-1, 1, 3)}

# Grid search
# Create a grid search: 12 parameter combinations, 3-fold cross-validation
gs = GridSearchCV(clf, parameters, verbose=2, refit=True, cv=3)

# Run the grid search as a single job (n_jobs is left at its default)
gs.fit(x_train, y_train)
print(gs.best_params_, gs.best_score_)
# Print the accuracy of the best (refitted) model on the test set
print(gs.score(x_test, y_test))
'''
Fitting 3 folds for each of 12 candidates, totalling 36 fits
[CV] svc__C=0.1, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.3s
[Parallel(n_jobs=1)]: Done   1 out of   1 | elapsed:    8.3s remaining:    0.0s
[CV] svc__C=0.1, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.5s
[CV] svc__C=0.1, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=0.01 -   8.5s
[CV] svc__C=0.1, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.4s
[CV] svc__C=0.1, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.5s
[CV] svc__C=0.1, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=0.1 -   8.5s
[CV] svc__C=0.1, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.4s
[CV] svc__C=0.1, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.6s
[CV] svc__C=0.1, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=0.1, svc__gamma=1.0 -   8.6s
[CV] svc__C=0.1, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.5s
[CV] svc__C=0.1, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.6s
[CV] svc__C=0.1, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=0.1, svc__gamma=10.0 -   8.7s
[CV] svc__C=1.0, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.3s
[CV] svc__C=1.0, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.4s
[CV] svc__C=1.0, svc__gamma=0.01 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=0.01 -   8.5s
[CV] svc__C=1.0, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.3s
[CV] svc__C=1.0, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.4s
[CV] svc__C=1.0, svc__gamma=0.1 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=0.1 -   8.5s
[CV] svc__C=1.0, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.5s
[CV] svc__C=1.0, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.6s
[CV] svc__C=1.0, svc__gamma=1.0 ......................................
[CV] ............................. svc__C=1.0, svc__gamma=1.0 -   8.7s
[CV] svc__C=1.0, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.5s
[CV] svc__C=1.0, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.6s
[CV] svc__C=1.0, svc__gamma=10.0 .....................................
[CV] ............................ svc__C=1.0, svc__gamma=10.0 -   8.7s
[CV] svc__C=10.0, svc__gamma=0.01 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.4s
[CV] svc__C=10.0, svc__gamma=0.01 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.4s
[CV] svc__C=10.0, svc__gamma=0.01 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=0.01 -   8.7s
[CV] svc__C=10.0, svc__gamma=0.1 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
[CV] svc__C=10.0, svc__gamma=0.1 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
[CV] svc__C=10.0, svc__gamma=0.1 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=0.1 -   8.6s
[CV] svc__C=10.0, svc__gamma=1.0 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=1.0 -   8.5s
[CV] svc__C=10.0, svc__gamma=1.0 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=1.0 -   8.6s
[CV] svc__C=10.0, svc__gamma=1.0 .....................................
[CV] ............................ svc__C=10.0, svc__gamma=1.0 -   9.3s
[CV] svc__C=10.0, svc__gamma=10.0 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.8s
[CV] svc__C=10.0, svc__gamma=10.0 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.9s
[CV] svc__C=10.0, svc__gamma=10.0 ....................................
[CV] ........................... svc__C=10.0, svc__gamma=10.0 -   8.7s

12 hyperparameter combinations x 3-fold cross-validation = 36 fits, 5.2 minutes in total
[Parallel(n_jobs=1)]: Done  36 out of  36 | elapsed:  5.2min finished

Best parameters, best mean cross-validation score
{'svc__C': 10.0, 'svc__gamma': 0.1} 0.7906666666666666
Test-set score of the best model
0.8226666666666667

'''
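
The log above follows a fixed recipe: every candidate is trained and validated once per fold (12 candidates x 3 folds = 36 fits), best_score_ is the winning candidate's mean accuracy over the validation folds (not a score on the training data), and refit=True retrains that candidate on the whole training set before gs.score touches the test set. Below is a from-scratch sketch of that recipe in plain Python. Everything in it is a hypothetical stand-in: the 1-D k-nearest-neighbour classifier and its parameter k replace the TF-IDF + SVC pipeline, the toy data replaces the newsgroup texts, and the folds are interleaved rather than stratified, so this is an illustration, not scikit-learn's actual implementation.

```python
from itertools import product


def knn_predict(train, x, k):
    """Hypothetical toy model: majority label of the k nearest 1-D points."""
    nearest = sorted(train, key=lambda point: abs(point[0] - x))[:k]
    votes = [label for _, label in nearest]
    return max(set(votes), key=votes.count)


def accuracy(train, test, k):
    """Fraction of test points whose label is predicted correctly."""
    hits = sum(knn_predict(train, x, k) == y for x, y in test)
    return hits / len(test)


def grid_search_cv(data, param_grid, n_folds):
    """Score every parameter combination with n_folds-fold CV; return the best."""
    names = list(param_grid)
    candidates = [dict(zip(names, values)) for values in product(*param_grid.values())]
    # Interleaved, unshuffled folds: fold i holds data[i], data[i + n_folds], ...
    folds = [data[i::n_folds] for i in range(n_folds)]
    print(f"Fitting {n_folds} folds for each of {len(candidates)} candidates, "
          f"totalling {n_folds * len(candidates)} fits")
    best_params, best_score = None, float("-inf")
    for params in candidates:
        fold_scores = []
        for i, validation in enumerate(folds):
            # Train on every fold except the i-th, validate on the i-th.
            train = [p for j, fold in enumerate(folds) if j != i for p in fold]
            fold_scores.append(accuracy(train, validation, **params))
        mean_score = sum(fold_scores) / n_folds  # plays the role of best_score_
        if mean_score > best_score:
            best_params, best_score = params, mean_score
    return best_params, best_score


# Hypothetical toy data: label 1 when x >= 20, with two deliberately flipped labels.
noisy = {5, 30}
points = [(x, int(x >= 20) ^ (x in noisy)) for x in range(40)]
train_set = [p for p in points if p[0] % 4 != 0]  # 30 points
test_set = [p for p in points if p[0] % 4 == 0]   # 10 points

best_params, best_cv = grid_search_cv(train_set, {"k": [1, 3, 5]}, n_folds=3)
# prints: Fitting 3 folds for each of 3 candidates, totalling 9 fits

# refit=True: retrain the best candidate on the whole training set, then test once.
test_score = accuracy(train_set, test_set, **best_params)
print(best_params, best_cv, test_score)
```

Because k-NN has no real training step, "refitting" here just means predicting from all 30 training points instead of the 20 seen in each CV round; with an SVC it means calling fit again on the full training set, which is why the test score can differ from (and, as in the run above, even exceed) the mean cross-validation score.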

Reposted from www.cnblogs.com/Lin-Yi/p/9000989.html