La implementación de sklearn de regresión logística

Directorio de artículos

1. Importe los módulos necesarios
2. Generar datos
3. Construcción de modelos
4. Formación de modelos
5. Predicción del modelo
6.modelo de regresión logística
7. Dibuja la curva de predicción.
8. Calcule la precisión del índice de evaluación.

Contenido del texto:

1. Importe los módulos necesarios

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

2. Generar datos

2.1 Definir la función de generación de datos
def create_data(data_num=100):
    np.random.seed(21)
    x1=np.random.normal(1,0.2,data_num)
    x2=np.random.normal(2,0.2,data_num)
    x=np.append(x1,x2)
    y=np.array([0]*data_num+[1]*data_num)
    return x,y
2.2 Generar datos
X,y=create_data(1000)
X #查看X的数据
array([0.98960715, 0.97776079, 1.20835936, ..., 1.84049108, 2.14936146,
       1.90338769])
y #查看y的数据
array([0, 0, 0, ..., 1, 1, 1])
2.3 Dividir el conjunto de entrenamiento y el conjunto de prueba
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test=train_test_split(
    X,y,test_size=0.3,random_state=16)
2.4 Dibujar un diagrama de dispersión de los datos del conjunto de entrenamiento
plt.scatter(X_train,y_train,color='blue',s=20)
plt.show()

Diagrama de dispersión del conjunto de entrenamiento

2.5 Dibujar un diagrama de dispersión de los datos del conjunto de prueba
plt.scatter(X_test,y_test,color='g',s=20)
plt.show()

Gráfico de dispersión de los datos del conjunto de prueba

3. Construcción de modelos

from sklearn.linear_model import LogisticRegression
model=LogisticRegression()

4. Formación de modelos

  • Entrenamiento del modelo de regresión lineal sklearn.linear_model.LogisticRegression.fit
  • Parámetros utilizados:
    —X: función de entrada, si la entrada está en formato np.array, la forma debe ser (n_sample, n_feature).
    -Y: Ingrese la etiqueta.
X_train=X_train.reshape(-1,1)
model.fit(X=X_train,y=y_train)
LogisticRegression() #上述两行代码运行的输出

5. Predicción del modelo

  • Haga predicciones en el conjunto de prueba
  • Modelo de predicción de regresión lineal: sklearn.linear_model.LogisticRegression.predict
  • Parámetros utilizados:
    —X: función de entrada, si la entrada está en formato np.array, la forma debe ser (n_sample, n_feature).
    -C: Resultado de la previsión.
X_test=X_test.reshape(-1,1)
y_test_pred=model.predict(X=X_test)# 默认阀值为0.5
y_test_pred_proba=model.predict_proba(X=X_test) # 可以自定义阀值,比如自定义阀值0.6
Tome el umbral para juzgar los dos resultados de clasificación de la probabilidad.
def thes_func(x):
    thes=0.6
    return 1 if x>thes else 0
y_test_pred_thes=list(map(thes_func,y_test_pred_proba[:,1]))

6. Ver el coeficiente wy el intercepto b del modelo de regresión logística

  • Coeficiente de regresión: sklearn.linear_model.LogisticRegression.coef_
  • Término de intercepción: sklearn.linear_model.LogisticRegression.intercep_
w,b=model.coef_[0],model.intercept_
print('Weight={0}bias={1}'.format(w,b))
Weight=[9.53805539]bias=[-14.3705638]# print的输出结果

7. Dibuja la curva de predicción.

  • función scipy.special.expit, también conocida como función sigmoidea logística, definición: expit (x) = 1 / (1 + ex)
  • Parámetros:
    -x: la entrada de la función sigmoidea, el requisito de entrada es el formato de matriz np.array.
    --Out: La salida de la función sigmoidea, devuelta en el formato de np.array, con la misma forma que la entrada x.
from scipy.special import expit
X_train=X_train.reshape(-1)
X_test=X_test.reshape(-1)
sigmoid=expit(np.sort(X_test)*model.coef_[0]+model.intercept_)
plt.plot(np.sort(X_test),sigmoid,color='g')
plt.scatter(X_test,y_test,color='r',label='test dataset')
plt.legend()
plt.show()

Inserte la descripción de la imagen aquí

8. Calcule la precisión del índice de evaluación

  • Error cuadrático medio: sklearn.metrics.accuracy_score
  • Parámetros utilizados:
    —y_true: ground_truth
    —y_pred: valor predicho.
    Devuelve:
    -pérdida: resultado del cálculo de precisión.
from sklearn.metrics import accuracy_score
acc=accuracy_score(y_true=y_test,y_pred=y_test_pred)
print('Accuracy:{}'.format(acc))
Accuracy:0.9916666666666667 # print输出的结果

Supongo que te gusta

Origin blog.csdn.net/weixin_42961082/article/details/113805473
Recomendado
Clasificación