数据分析——交叉验证

使用cross_val_score可以做交叉验证,learning_curve和validation_curve也可以。

from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.neighbors import KNeighborsClassifier
import matplotlib.pyplot as plt
%matplotlib inline

iris = load_iris()
x_data = iris.data
y_data = iris.target

k_score = []
for k in range(1,31):
    knn = KNeighborsClassifier(n_neighbors=k)
    score = cross_val_score(knn,x_data,y_data,cv=10,scoring='accuracy')
    k_score.append(score.mean())

plt.figure()
plt.plot(range(1,31),k_score)

# `sklearn.learning_curve` was removed in scikit-learn 0.20; `learning_curve`
# is provided by `sklearn.model_selection`.
from sklearn.model_selection import learning_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC

digits = load_digits()
# The estimator that is actually evaluated below (previously an unused
# default SVC() was created here and a second SVC was built inline).
svc = SVC(gamma=0.001)
x_data = digits.data
y_data = digits.target

# 10-fold cross-validated accuracy at 10%, 25%, 50%, 75% and 100% of the
# training data. Each returned score array has one row per training size
# and one column per fold.
train_size, train_score, test_score = learning_curve(
    svc, x_data, y_data,
    cv=10,
    scoring='accuracy',
    train_sizes=[0.1, 0.25, 0.5, 0.75, 1],
)

# 'accuracy' is a positive score in [0, 1], so negating it (as is done for
# negated-error scorers) yields meaningless negative values. Convert it to
# an error rate (1 - accuracy) so the "loss" curves are real losses.
train_loss_mean = 1 - train_score.mean(axis=1)
test_loss_mean = 1 - test_score.mean(axis=1)

# Start a fresh figure so these curves are not drawn on top of an earlier plot.
plt.figure()
plt.plot(train_size, train_loss_mean, 'r-o', label='train_loss')
plt.plot(train_size, test_loss_mean, 'g-o', label='test_loss')
plt.legend()

# `np` was used below without ever being imported.
import numpy as np
# `sklearn.learning_curve` was removed in scikit-learn 0.20; `validation_curve`
# is provided by `sklearn.model_selection`.
from sklearn.model_selection import validation_curve
from sklearn.datasets import load_digits
from sklearn.svm import SVC
digits = load_digits()
x_data = digits.data
y_data = digits.target

# Values of gamma to evaluate: 1e-6, 1e-5, 1e-4, 1e-3, 1e-2. Computed once and
# reused for both the evaluation and the x-axis.
param_range = np.logspace(-6, -2, 5)

# 10-fold cross-validated accuracy for each gamma. Each returned score array
# has one row per gamma value and one column per fold.
train_score, test_score = validation_curve(
    SVC(), x_data, y_data,
    param_name='gamma',
    param_range=param_range,
    cv=10,
    scoring='accuracy',
)

# 'accuracy' is a positive score in [0, 1]; convert it to an error rate
# (1 - accuracy) instead of negating it, so the "loss" curves are real losses.
train_loss_mean = 1 - train_score.mean(axis=1)
test_loss_mean = 1 - test_score.mean(axis=1)

plt.figure()
plt.plot(param_range, train_loss_mean, 'r-o', label='train_loss')
plt.plot(param_range, test_loss_mean, 'g-o', label='test_loss')
# gamma is log-spaced; a linear axis would squash the small values together.
plt.xscale('log')
plt.legend()

猜你喜欢

转载自www.cnblogs.com/slowlyslowly/p/8856711.html