Choosing a machine learning classification model

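Six algorithms (logistic regression, linear discriminant analysis, k-nearest neighbors, decision tree, Gaussian naive Bayes, and SVM) are used to classify the iris dataset, and cross-validation is used to estimate each model's accuracy.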

Import the required libraries:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.model_selection import cross_val_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
from matplotlib import pyplot

Load the dataset:

#Load the dataset
data = load_iris()
X = data['data']
y = data['target']
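Before training, it can help to confirm what `load_iris` returns. The following check uses only standard scikit-learn and NumPy calls and shows that iris has 150 samples with 4 numeric features and 3 balanced classes:

```python
import numpy as np
from sklearn.datasets import load_iris

# Load iris the same way as above
data = load_iris()
X = data['data']
y = data['target']

print(X.shape)                # (150, 4): 150 samples, 4 features
print(data['feature_names'])  # sepal/petal length and width, in cm
print(data['target_names'])   # the three species
print(np.bincount(y))         # [50 50 50]: 50 samples per class
```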

Split the data into a training set and a hold-out validation set (20%):

#Split into training and validation sets
validation_size = 0.2
seed = 7
X_train,X_validation,Y_train,Y_validation = train_test_split(X,y,test_size = validation_size,random_state = seed)
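The split above is purely random, so the 30 validation samples need not contain the three classes in equal numbers. One optional variant, not used in the original, passes `stratify=y` to `train_test_split` so that both parts keep the 1:1:1 class ratio. A minimal sketch with the same seed and test size:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
X, y = data['data'], data['target']

# Same 80/20 split and seed as above, but stratified on the labels
X_train, X_validation, Y_train, Y_validation = train_test_split(
    X, y, test_size=0.2, random_state=7, stratify=y)

# 150 * 0.2 = 30 validation samples; stratification puts 10 of each class there
print(len(X_train), len(X_validation))  # 120 30
print(np.bincount(Y_validation))        # [10 10 10]
```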

Create the models:

models = {}
models['LR'] = LogisticRegression(max_iter=1000)  # raise the iteration limit to avoid a ConvergenceWarning on unscaled iris
models['LDA'] = LinearDiscriminantAnalysis()
models['KNN'] = KNeighborsClassifier()
models['CART'] = DecisionTreeClassifier()
models['NB'] = GaussianNB()
models['SVM'] = SVC()
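Logistic regression, KNN, and SVM are sensitive to feature scale. The original does not scale the features, and the iris features all share the same unit (cm), so any gain may be small. One hedged variant wraps these models in a scikit-learn `Pipeline` with `StandardScaler`, so that during cross-validation the scaler is fitted only on each training fold. The `Scaled*` names below are hypothetical labels chosen for this sketch:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

data = load_iris()
X_train, X_validation, Y_train, Y_validation = train_test_split(
    data['data'], data['target'], test_size=0.2, random_state=7)

# Each pipeline standardizes the features, then fits the classifier.
# cross_val_score refits the whole pipeline (scaler included) on every fold.
scaled_models = {
    'ScaledLR': make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000)),
    'ScaledKNN': make_pipeline(StandardScaler(), KNeighborsClassifier()),
    'ScaledSVM': make_pipeline(StandardScaler(), SVC()),
}

kfold = KFold(n_splits=10, shuffle=True, random_state=7)
mean_scores = {}
for name, model in scaled_models.items():
    scores = cross_val_score(model, X_train, Y_train, cv=kfold, scoring='accuracy')
    mean_scores[name] = scores.mean()
    print('%s: %f (%f)' % (name, scores.mean(), scores.std()))
```

These pipelines can be added to the `models` dictionary and compared alongside the unscaled versions.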

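Train each model and estimate its accuracy with 10-fold cross-validation on the training set, printing the mean accuracy (standard deviation in parentheses):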

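results = []
# shuffle=True is required when random_state is set (scikit-learn >= 0.24 raises a ValueError otherwise)
kfold = KFold(n_splits=10, shuffle=True, random_state=seed)
for key in models:
    cv_results = cross_val_score(models[key], X_train, Y_train, cv=kfold, scoring='accuracy')
    results.append(cv_results)
    print('%s: %f (%f)' % (key, cv_results.mean(), cv_results.std()))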
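To make clear what `cross_val_score` does here, the following sketch repeats the 10-fold evaluation by hand for SVM. For every fold it clones a fresh, unfitted estimator, fits it on nine folds, and scores it on the held-out fold. Because `KFold` is given an integer `random_state`, each call to `split` yields the same folds, so the manual scores should match the library's:

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.svm import SVC

data = load_iris()
X_train, X_validation, Y_train, Y_validation = train_test_split(
    data['data'], data['target'], test_size=0.2, random_state=7)

base_model = SVC()
kfold = KFold(n_splits=10, shuffle=True, random_state=7)

manual_scores = []
for train_idx, test_idx in kfold.split(X_train):
    # A fresh, unfitted copy of the estimator for every fold
    model = clone(base_model)
    model.fit(X_train[train_idx], Y_train[train_idx])
    preds = model.predict(X_train[test_idx])
    manual_scores.append(accuracy_score(Y_train[test_idx], preds))
manual_scores = np.array(manual_scores)

# cross_val_score with the same splitter produces the same per-fold accuracies
auto_scores = cross_val_score(base_model, X_train, Y_train, cv=kfold, scoring='accuracy')
print(manual_scores.mean(), auto_scores.mean())
```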

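The cross-validation results showed the SVM model with the highest mean accuracy, so SVM is chosen from among these algorithms to build the final model. First, visualize the cross-validation accuracy of each model: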

#Compare the algorithms with a box plot (or a histogram)
fig = pyplot.figure()
fig.suptitle('Algorithm comparison')
ax = fig.add_subplot(111)
ax.boxplot(results)
#ax.hist(results)  # alternative: histogram of the same scores
ax.set_xticklabels(list(models.keys()))
pyplot.show()
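Reading the winner off the box plot works for six models. A minimal sketch that ranks them numerically instead, by mean cross-validation accuracy with ties broken by the smaller standard deviation, is shown below. The tie-breaking rule and the `random_state` on the decision tree are assumptions added here. The resulting order can also differ from the original run because of library versions and the shuffled folds:

```python
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_val_score, train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
X_train, X_validation, Y_train, Y_validation = train_test_split(
    data['data'], data['target'], test_size=0.2, random_state=7)

models = {
    'LR': LogisticRegression(max_iter=1000),
    'LDA': LinearDiscriminantAnalysis(),
    'KNN': KNeighborsClassifier(),
    'CART': DecisionTreeClassifier(random_state=7),  # fixed seed for reproducible splits
    'NB': GaussianNB(),
    'SVM': SVC(),
}

kfold = KFold(n_splits=10, shuffle=True, random_state=7)
summary = {}
for name, model in models.items():
    scores = cross_val_score(model, X_train, Y_train, cv=kfold, scoring='accuracy')
    summary[name] = (scores.mean(), scores.std())

# Highest mean first; among equal means, the lower standard deviation wins
ranking = sorted(summary.items(), key=lambda item: (-item[1][0], item[1][1]))
for name, (mean, std) in ranking:
    print('%-5s %.4f (%.4f)' % (name, mean, std))
best_name = ranking[0][0]
print('best:', best_name)
```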

Fit the SVM model on the training set and predict the validation set:

#Make predictions
svm = SVC()
svm.fit(X=X_train, y=Y_train)  # the keyword arguments are uppercase X and lowercase y
predictions = svm.predict(X_validation)
print(accuracy_score(Y_validation, predictions))  # accuracy
print(confusion_matrix(Y_validation, predictions))  # confusion matrix
print(classification_report(Y_validation, predictions))  # precision, recall, F1-score
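The three printed reports are related. In scikit-learn's confusion matrix, rows are true classes and columns are predicted classes, so the diagonal counts correct predictions. Accuracy is the diagonal sum divided by the total, and per-class recall is each diagonal entry divided by its row sum. A small sketch that recomputes these from the matrix and checks them against the library functions:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

data = load_iris()
X_train, X_validation, Y_train, Y_validation = train_test_split(
    data['data'], data['target'], test_size=0.2, random_state=7)

svm = SVC().fit(X_train, Y_train)
predictions = svm.predict(X_validation)
cm = confusion_matrix(Y_validation, predictions)

# Accuracy: correctly classified samples (diagonal) over all samples
acc_from_cm = np.trace(cm) / cm.sum()
# Recall per class: diagonal entry over its row (true-class) total
recall_from_cm = np.diag(cm) / cm.sum(axis=1)
# Precision per class: diagonal entry over its column (predicted-class) total
precision_from_cm = np.diag(cm) / cm.sum(axis=0)

print(acc_from_cm, accuracy_score(Y_validation, predictions))
print(recall_from_cm, recall_score(Y_validation, predictions, average=None))
print(precision_from_cm)
```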

Reposted from blog.csdn.net/spartanfuk/article/details/81490098