Use scikit-plot para visualizar modelos de aprendizaje automático entrenados (incluidas curvas ROC de varias clases, matrices de confusión, etc.)

Tabla de contenido

1. Instalación

2. Dibujo del caso

1) Visualización de indicadores de evaluación

1. Matriz de confusión

2. Curva ROC multicategoría

3. Gráfico estadístico de KS

4. Curva PR

5. Análisis de análisis de silueta.

6. La curva de corrección del clasificador

2) Visualización del modelo

1. Curvas de aprendizaje de entrenamiento y prueba bajo diferentes muestras de entrenamiento

2. La importancia de las características visuales

3) Visualización de grupos

1. Diagrama de codo de agrupamiento

4) Visualización de reducción de dimensionalidad

1. La razón de varianza explicada del componente PCA

2. Diagrama de dispersión después de la reducción de la dimensionalidad de PCA


scikit-learn (sklearn)Es una biblioteca común de aprendizaje automático en el entorno Python, que incluye algoritmos comunes de clasificación, regresión y agrupamiento. Después de entrenar el modelo, una operación común es visualizar el modelo, que debe mostrarse Matplotlibmediante .

scikit-plotEs una biblioteca basada en sklearny Matplotlib, la función principal es visualizar el modelo entrenado, la función es relativamente simple y fácil de entender.

1. Instalación

pip install scikit-plot -i https://pypi.tuna.tsinghua.edu.cn/simple

2. Dibujo del caso

1) Visualización de indicadores de evaluación

1. Matriz de confusión

import scikitplot as skplt
rf = RandomForestClassifier()
rf = rf.fit(X_train, y_train)
y_pred = rf.predict(X_test)

skplt.metrics.plot_confusion_matrix(y_test, y_pred, normalize=True)
plt.show()

2. Curva ROC multicategoría

import scikitplot as skplt
nb = GaussianNB()
nb = nb.fit(X_train, y_train)
y_probas = nb.predict_proba(X_test)

skplt.metrics.plot_roc(y_test, y_probas)
plt.show()

3. Gráfico estadístico de KS

import scikitplot as skplt
lr = LogisticRegression()
lr = lr.fit(X_train, y_train)
y_probas = lr.predict_proba(X_test)

skplt.metrics.plot_ks_statistic(y_test, y_probas)
plt.show()

4. Curva PR

import scikitplot as skplt
nb = GaussianNB()
nb.fit(X_train, y_train)
y_probas = nb.predict_proba(X_test)

skplt.metrics.plot_precision_recall(y_test, y_probas)
plt.show()

5. Análisis de análisis de silueta.

import scikitplot as skplt
kmeans = KMeans(n_clusters=4, random_state=1)
cluster_labels = kmeans.fit_predict(X)

skplt.metrics.plot_silhouette(X, cluster_labels)
plt.show()

6. La curva de corrección del clasificador

import scikitplot as skplt
rf = RandomForestClassifier()
lr = LogisticRegression()
nb = GaussianNB()
svm = LinearSVC()
rf_probas = rf.fit(X_train, y_train).predict_proba(X_test)
lr_probas = lr.fit(X_train, y_train).predict_proba(X_test)
nb_probas = nb.fit(X_train, y_train).predict_proba(X_test)
svm_scores = svm.fit(X_train, y_train).decision_function(X_test)
probas_list = [rf_probas, lr_probas, nb_probas, svm_scores]
clf_names = ['Random Forest', 'Logistic Regression',
              'Gaussian Naive Bayes', 'Support Vector Machine']

skplt.metrics.plot_calibration_curve(y_test,
                                      probas_list,
                                      clf_names)
plt.show()

2) Visualización del modelo

1. Curvas de aprendizaje de entrenamiento y prueba bajo diferentes muestras de entrenamiento

import scikitplot as skplt
rf = RandomForestClassifier()

skplt.estimators.plot_learning_curve(rf, X, y)
plt.show()

2. La importancia de las características visuales

import scikitplot as skplt
rf = RandomForestClassifier()
rf.fit(X, y)

skplt.estimators.plot_feature_importances(
     rf, feature_names=['petal length', 'petal width',
                        'sepal length', 'sepal width'])
plt.show()

3) Visualización de grupos

1. Diagrama de codo de agrupamiento

import scikitplot as skplt
kmeans = KMeans(random_state=1)

skplt.cluster.plot_elbow_curve(kmeans, cluster_ranges=range(1, 30))
plt.show()

4) Visualización de reducción de dimensionalidad

1. La razón de varianza explicada del componente PCA

import scikitplot as skplt
pca = PCA(random_state=1)
pca.fit(X)

skplt.decomposition.plot_pca_component_variance(pca)
>plt.show()

2. Diagrama de dispersión después de la reducción de la dimensionalidad de PCA

import scikitplot as skplt
pca = PCA(random_state=1)
pca.fit(X)

skplt.decomposition.plot_pca_2d_projection(pca, X, y)
plt.show()

Supongo que te gusta

Origin blog.csdn.net/qq_45100200/article/details/131268560
Recomendado
Clasificación