Conceptos básicos del aprendizaje automático 17: explicación de todo el proceso de entrenamiento de un modelo basado en el conjunto de datos de Boston House Price


El aprendizaje automático es una habilidad experiencial y la práctica es una de las formas efectivas de dominar el aprendizaje automático y mejorar la capacidad de utilizar el aprendizaje automático para resolver problemas. Entonces, ¿cómo se puede utilizar el aprendizaje automático para resolver problemas?
Esta sección presentará un problema de regresión paso a paso mediante un ejemplo.
Este capítulo presenta principalmente los siguientes contenidos:

  • Cómo completar un modelo de un problema de regresión de principio a fin.
  • Cómo mejorar la precisión del modelo mediante la transformación de datos.
  • Cómo mejorar la precisión del modelo ajustando parámetros.
  • Cómo mejorar la precisión del modelo mediante algoritmos de conjunto.

1 Definir el problema

En este proyecto, analizaremos y estudiaremos el conjunto de datos sobre precios de la vivienda en Boston . Cada fila de datos de este conjunto de datos describe los precios de la vivienda en Boston o en las ciudades. Los datos fueron recopilados estadísticamente en 1978. Los datos contienen las siguientes 14 características y 506 datos (como se define en UCI Machine Learning Warehouse).

· CRIM: tasa de criminalidad urbana per cápita.
· ZN: Proporción de suelo residencial.
· INDUS: Proporción de suelo no residencial en ciudades y pueblos.
· CHAS: Variable ficticia CHAS, utilizada para análisis de regresión.
· NOX: índice ambiental.
· RM: número de habitaciones por vivienda.
· EDAD: Proporción de unidades ocupadas por sus propietarios construidas antes de 1940.
· DIS: Distancia ponderada a 5 centros de empleo de Boston.
· RAD: índice de conveniencia de la distancia a autopistas.
· IMPUESTO: Tasa del impuesto sobre bienes inmuebles por $10,000.
· PRTATIO: Ratio profesor-alumnos en una localidad.
· B: Proporción de negros en la localidad.
· LSTAT: Cuántos propietarios de la zona son de bajos ingresos.
· MEDV: Precio medio de la vivienda para propietarios-ocupantes.

A través de la descripción de estos atributos de características, podemos encontrar que las unidades de medida de los atributos de características de entrada no son uniformes y puede ser necesario ajustar la unidad de medida de los datos.

2 Importar datos

Primero importe las bibliotecas de clases necesarias en el proyecto. El código se muestra a continuación:

import pandas as pd
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import  KFold, cross_val_score
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

A continuación, importe el conjunto de datos a Python. Este conjunto de datos también se puede descargar desde el repositorio de aprendizaje automático de UCI. Al importar el conjunto de datos, también se establece el nombre de la característica del atributo de datos.

El código se muestra a continuación:

#导入数据
path = 'D:\down\\BostonHousing.csv'
data = pd.read_csv(path)

3. Comprender los datos

Analizar datos importados para construir modelos apropiados. Primero, observe las dimensiones de los datos, como cuántos registros hay en el conjunto de datos y cuántas características de los datos hay.

El código se muestra a continuación:

print('data.shape=',data.shape)

Luego de la ejecución, podemos ver que hay un total de 506 registros y 14 atributos de características, lo cual es consistente con la información proporcionada por UCI.

data.shape= (506, 14)

Luego verifique los tipos de campo de cada atributo de característica . El código se muestra a continuación:

#特征属性字段类型
print(data.dtypes)

Se puede ver que todos los atributos de las características son números, la mayoría de los atributos de las características son
números de punto flotante y algunos atributos de las características son tipos enteros. Los resultados de la ejecución son los siguientes:

crim       float64
zn         float64
indus      float64
chas         int64
nox        float64
rm         float64
age        float64
dis        float64
rad          int64
tax          int64
ptratio    float64
b          float64
lstat      float64
medv       float64
dtype: object

A continuación, haremos una revisión simple de los datos, aquí miramos los primeros 30 registros. El código se muestra a continuación:

print(data.head(30))

Los resultados de la ejecución son los siguientes:

       crim    zn  indus  chas    nox  ...  tax  ptratio       b  lstat  medv
0   0.00632  18.0   2.31     0  0.538  ...  296     15.3  396.90   4.98  24.0
1   0.02731   0.0   7.07     0  0.469  ...  242     17.8  396.90   9.14  21.6
2   0.02729   0.0   7.07     0  0.469  ...  242     17.8  392.83   4.03  34.7
3   0.03237   0.0   2.18     0  0.458  ...  222     18.7  394.63   2.94  33.4
4   0.06905   0.0   2.18     0  0.458  ...  222     18.7  396.90   5.33  36.2
5   0.02985   0.0   2.18     0  0.458  ...  222     18.7  394.12   5.21  28.7
6   0.08829  12.5   7.87     0  0.524  ...  311     15.2  395.60  12.43  22.9
7   0.14455  12.5   7.87     0  0.524  ...  311     15.2  396.90  19.15  27.1
8   0.21124  12.5   7.87     0  0.524  ...  311     15.2  386.63  29.93  16.5
9   0.17004  12.5   7.87     0  0.524  ...  311     15.2  386.71  17.10  18.9
10  0.22489  12.5   7.87     0  0.524  ...  311     15.2  392.52  20.45  15.0
11  0.11747  12.5   7.87     0  0.524  ...  311     15.2  396.90  13.27  18.9
12  0.09378  12.5   7.87     0  0.524  ...  311     15.2  390.50  15.71  21.7
13  0.62976   0.0   8.14     0  0.538  ...  307     21.0  396.90   8.26  20.4
14  0.63796   0.0   8.14     0  0.538  ...  307     21.0  380.02  10.26  18.2
15  0.62739   0.0   8.14     0  0.538  ...  307     21.0  395.62   8.47  19.9
16  1.05393   0.0   8.14     0  0.538  ...  307     21.0  386.85   6.58  23.1
17  0.78420   0.0   8.14     0  0.538  ...  307     21.0  386.75  14.67  17.5
18  0.80271   0.0   8.14     0  0.538  ...  307     21.0  288.99  11.69  20.2
19  0.72580   0.0   8.14     0  0.538  ...  307     21.0  390.95  11.28  18.2
20  1.25179   0.0   8.14     0  0.538  ...  307     21.0  376.57  21.02  13.6
21  0.85204   0.0   8.14     0  0.538  ...  307     21.0  392.53  13.83  19.6
22  1.23247   0.0   8.14     0  0.538  ...  307     21.0  396.90  18.72  15.2
23  0.98843   0.0   8.14     0  0.538  ...  307     21.0  394.54  19.88  14.5
24  0.75026   0.0   8.14     0  0.538  ...  307     21.0  394.33  16.30  15.6
25  0.84054   0.0   8.14     0  0.538  ...  307     21.0  303.42  16.51  13.9
26  0.67191   0.0   8.14     0  0.538  ...  307     21.0  376.88  14.81  16.6
27  0.95577   0.0   8.14     0  0.538  ...  307     21.0  306.38  17.28  14.8
28  0.77299   0.0   8.14     0  0.538  ...  307     21.0  387.94  12.80  18.4
29  1.00245   0.0   8.14     0  0.538  ...  307     21.0  380.23  11.98  21.0

A continuación veamos las estadísticas descriptivas de los datos. El código se muestra a continuación:

#pandas 新版本
pd.options.display.precision=1
#pandas老版本
#pd.set_option("precision", 1)

La información estadística descriptiva incluye el valor máximo, el valor mínimo, el valor mediano, el valor cuartil,
etc. de los datos. El análisis de estos datos puede profundizar la comprensión de la distribución de los datos, la estructura de los datos, etc. Los resultados son los siguientes

          crim     zn  indus     chas  ...  ptratio      b  lstat   medv
count  5.1e+02  506.0  506.0  5.1e+02  ...    506.0  506.0  506.0  506.0
mean   3.6e+00   11.4   11.1  6.9e-02  ...     18.5  356.7   12.7   22.5
std    8.6e+00   23.3    6.9  2.5e-01  ...      2.2   91.3    7.1    9.2
min    6.3e-03    0.0    0.5  0.0e+00  ...     12.6    0.3    1.7    5.0
25%    8.2e-02    0.0    5.2  0.0e+00  ...     17.4  375.4    6.9   17.0
50%    2.6e-01    0.0    9.7  0.0e+00  ...     19.1  391.4   11.4   21.2
75%    3.7e+00   12.5   18.1  0.0e+00  ...     20.2  396.2   17.0   25.0
max    8.9e+01  100.0   27.7  1.0e+00  ...     22.0  396.9   38.0   50.0

A continuación, echemos un vistazo a la correlación por pares entre las características de los datos, aquí observamos el coeficiente de correlación de Pearson de los datos. El código se muestra a continuación:

         crim    zn  indus      chas   nox  ...   tax  ptratio     b  lstat  medv
crim     1.00 -0.20   0.41 -5.59e-02  0.42  ...  0.58     0.29 -0.39   0.46 -0.39
zn      -0.20  1.00  -0.53 -4.27e-02 -0.52  ... -0.31    -0.39  0.18  -0.41  0.36
indus    0.41 -0.53   1.00  6.29e-02  0.76  ...  0.72     0.38 -0.36   0.60 -0.48
chas    -0.06 -0.04   0.06  1.00e+00  0.09  ... -0.04    -0.12  0.05  -0.05  0.18
nox      0.42 -0.52   0.76  9.12e-02  1.00  ...  0.67     0.19 -0.38   0.59 -0.43
rm      -0.22  0.31  -0.39  9.13e-02 -0.30  ... -0.29    -0.36  0.13  -0.61  0.70
age      0.35 -0.57   0.64  8.65e-02  0.73  ...  0.51     0.26 -0.27   0.60 -0.38
dis     -0.38  0.66  -0.71 -9.92e-02 -0.77  ... -0.53    -0.23  0.29  -0.50  0.25
rad      0.63 -0.31   0.60 -7.37e-03  0.61  ...  0.91     0.46 -0.44   0.49 -0.38
tax      0.58 -0.31   0.72 -3.56e-02  0.67  ...  1.00     0.46 -0.44   0.54 -0.47
ptratio  0.29 -0.39   0.38 -1.22e-01  0.19  ...  0.46     1.00 -0.18   0.37 -0.51
b       -0.39  0.18  -0.36  4.88e-02 -0.38  ... -0.44    -0.18  1.00  -0.37  0.33
lstat    0.46 -0.41   0.60 -5.39e-02  0.59  ...  0.54     0.37 -0.37   1.00 -0.74
medv    -0.39  0.36  -0.48  1.75e-01 -0.43  ... -0.47    -0.51  0.33  -0.74  1.00

[14 rows x 14 columns]

De los resultados anteriores, podemos ver que algunos atributos de características tienen fuertes correlaciones (>0,7 o <-0,7), como por ejemplo:
· El coeficiente de correlación de Pearson entre NOX e INDUS es 0,76.
· El coeficiente de correlación de Pearson entre DIS e INDUS es -0,71.
· El coeficiente de correlación de Pearson entre IMPUESTO e INDUS es 0,72.
· El coeficiente de correlación de Pearson entre AGE y NOX es 0,73.
· El coeficiente de correlación de Pearson entre DIS y NOX es -0,77.


4 Visualización de datos

gráfico de característica única

Primero mire los diagramas de distribución individuales para cada característica de datos. Mirar varios diagramas diferentes le ayudará a descubrir mejores métodos. Podemos sentir la distribución de los datos al observar el histograma de cada característica de datos. El código se muestra a continuación:

data.hist(sharex=False,sharey=False,xlabelsize=1,ylabelsize=1)
pyplot.show()

Los resultados de la ejecución se muestran en la siguiente figura. En la figura, puede ver que algunos datos están distribuidos exponencialmente, como
CRIM, ZN, AGE y B, y algunas características de los datos son bimodales, como RAD e TAX.

Insertar descripción de la imagen aquí

Los atributos característicos de estos datos se pueden mostrar a través de gráficos de densidad . Los gráficos de densidad muestran estas características de datos de manera más fluida que los histogramas. El código se muestra a continuación:

data.plot(kind='density',subplots=True,layout=(4,4),sharex=False,fontsize=1)
pyplot.show()

En el gráfico de densidad, especifique diseño = (4, 4), lo que significa que se va a dibujar un gráfico con cuatro filas y cuatro columnas
. Los resultados de la ejecución se muestran en la figura. Puede verificar el estado de cada característica de datos
Insertar descripción de la imagen aquí
a través del diagrama de caja y también puede ver fácilmente el grado de asimetría de la distribución de datos . El código se muestra a continuación:

data.plot(kind='box',subplots=True,layout=(4,4),sharex=False,fontsize=8)
pyplot.show()

Resultados de:

Insertar descripción de la imagen aquí

Múltiples gráficos de datos

A continuación, utilice varios gráficos de datos para ver la interacción entre diferentes características de datos. Primero eche un vistazo al diagrama de matriz de dispersión. El código se muestra a continuación:

#散点矩阵图
scatter_matrix(data)
pyplot.show()

Se puede ver en el gráfico de matriz de dispersión que, aunque la correlación entre algunas características de los datos es muy fuerte, la estructura de distribución de estos datos también es muy buena. Incluso si no es una estructura de distribución lineal, es una estructura de distribución que se puede predecir fácilmente. Los resultados de la ejecución se muestran en la figura.
Insertar descripción de la imagen aquí

Eche otro vistazo al diagrama de matriz de correlación de la influencia mutua de los datos. El código se muestra a continuación:

#相关矩阵图
names = ['crim', 'zn', 'indus', 'chas', 'nox', 'rm', 'age', 'dis', 'rad', 'tax','ptratio', 'b', 'lstat']
fig = pyplot.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(data.corr(), vmin =-1,vmax =1, interpolation='none')
fig.colorbar(cax)
ticks = np.arange(0,13,1)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.set_xticklabels(names)
ax.set_yticklabels(names)
pyplot.show()


Los resultados de la ejecución se muestran en la figura. Según la leyenda, podemos ver que existen correlaciones por pares entre los atributos de características de datos . Algunos atributos están fuertemente correlacionados. Se recomienda eliminar estos atributos de características en el procesamiento posterior para mejorar la precisión de la algoritmo Gastar.

Insertar descripción de la imagen aquí

A través de la correlación y distribución de datos, se descubre que la estructura de datos en el conjunto de datos es relativamente compleja y es necesario considerar la conversión de datos para mejorar la precisión del modelo. Puedes intentar procesar los datos desde los siguientes aspectos:
· Reducir la mayoría de las características altamente relevantes mediante la selección de características.
· Reducir el impacto de diferentes unidades de medida de datos mediante la estandarización de datos.
· Reducir diferentes estructuras de distribución de datos normalizando los datos para mejorar la precisión del algoritmo.

Puede observar más a fondo la clasificación de probabilidad (discretización) de los datos, lo que puede ayudar a mejorar la precisión del algoritmo del árbol de decisión.


5. Conjuntos de datos de evaluación separados

Es una buena idea separar un conjunto de datos de evaluación para garantizar que el conjunto de datos separado esté completamente aislado del conjunto de datos utilizado para entrenar el modelo, lo que ayuda en última instancia a juzgar e informar la precisión del modelo. En el paso final del proyecto, este conjunto de datos de evaluación se utiliza para confirmar la precisión del modelo. Aquí, el 20% de los datos se separa como conjunto de datos de evaluación y el 80% de los datos se utiliza como conjunto de datos de entrenamiento.

El código se muestra a continuación:

#分离数据集,分离出 20%的数据作为评估数据集,80%的数据作为训练数据集
array = data.values
X = array[:, 0:13]
Y = array[:, 13]
validation_size = 0.2
seed = 7
X_train,X_validation,Y_train,Y_validation = train_test_split(X,Y,test_size = validation_size,random_state=seed)

6 Algoritmo de evaluación

Después de analizar los datos, no se puede elegir inmediatamente qué algoritmo es más eficaz para el problema que se debe resolver. Creemos intuitivamente que debido a la distribución lineal de algunos datos, los algoritmos de regresión lineal y los algoritmos de regresión de red elástica pueden ser más efectivos para resolver problemas. Además, debido a la discretización de los datos, es posible generar modelos de alta precisión mediante el algoritmo de árbol de decisión o el algoritmo de máquina de vectores de soporte.

En este punto, todavía no está claro qué algoritmo generará el modelo más preciso, por lo que es necesario diseñar un marco de evaluación para seleccionar el algoritmo apropiado. Utilizamos validación cruzada diez veces para separar los datos y comparar la precisión de los algoritmos mediante el error cuadrático medio. Cuanto más cerca esté el error cuadrático medio de 0, mayor será la precisión del algoritmo.

El código se muestra a continuación:

seed = 7
num_folds = 10
scoring = 'neg_mean_squared_error'

No se realiza ningún procesamiento de los datos originales y el algoritmo se evalúa para formar un punto de referencia de evaluación del algoritmo. Este valor de referencia es el valor de referencia para comparar los pros y los contras de mejoras posteriores del algoritmo. Seleccionamos tres algoritmos lineales y tres algoritmos no lineales para comparar.

Algoritmos lineales: regresión lineal (LR), regresión de lazo (LASSO) y regresión de red elástica (EN).
Algoritmos no lineales: Árbol de clasificación y regresión (CART), Máquina de vectores de soporte (SVM) y K Algoritmo vecino más cercano (KNN).

El código para inicializar el modelo de algoritmo es el siguiente:

#评估算法

models = {
    
    }

models['LR'] = LogisticRegression()
models['LASSO'] = Lasso()
models['EN'] = ElasticNet()
models['KNN'] = KNeighborsClassifier()
models['CART'] = DecisionTreeClassifier()
models['SVM'] = SVR()



X_train,X_validation,Y_train,Y_validation = train_test_split(X,Y,test_size = validation_size,random_state=seed)

results = []

for key in models:

  kflod = KFold(n_splits=num_folds,random_state=seed,shuffle=True)
  result = cross_val_score(models[key], X_train, Y_train.astype('int'), cv=kflod,scoring= scoring)
  results.append(result)
  print("%s: %.3f (%.3f)" % (key, result.mean(), result.std()))

A juzgar por los resultados de la ejecución, Lasso Regression (LASSO) tiene el MSE óptimo, seguido del algoritmo Elastic Network Regression (EN). Los resultados de la ejecución son los siguientes:

LR: -59.150 (17.584)
LASSO: -27.313 (13.573)
EN: -28.251 (13.577)
KNN: -62.158 (28.251)
CART: -31.000 (19.562)
SVM: -68.676 (33.776)

Luego observe los resultados de todas las validaciones cruzadas de 10 veces. El código se muestra a continuación:

#评估算法箱线图

fig = pyplot.figure()
fig.suptitle("Algorithm Comparison")
ax = fig.add_subplot(111)
pyplot.boxplot(results)
ax.set_xticklabels(models.keys())
pyplot.show()


Los resultados de la ejecución se muestran en la figura. Se puede ver en la figura que la distribución del algoritmo lineal es relativamente similar
y la distribución de resultados del algoritmo del árbol de clasificación y regresión (CART) es muy compacta.

Insertar descripción de la imagen aquí

Algoritmo de evaluación: datos normalizados

Supongo que tal vez debido a que las unidades de medida de diferentes atributos de características en los datos originales son diferentes, los resultados de algunos algoritmos no son muy buenos. A continuación, estos algoritmos se evalúan nuevamente normalizando los datos. Aquí, el procesamiento de conversión de datos se realiza en el conjunto de datos de entrenamiento y todos los valores de las características de los datos se convierten en datos con "0" como valor mediano y desviación estándar como "1". Al normalizar datos, para evitar la fuga de datos, se utiliza Pipeline para normalizar los datos y evaluar el modelo. Para comparar con los resultados anteriores, aquí se adopta el mismo marco de evaluación para evaluar el modelo algorítmico.

El código se muestra a continuación:

#评估算法--正态化数据

pipelines ={
    
    }

pipelines['ScalerLR'] = Pipeline([('Scaler',StandardScaler()),('LR',LinearRegression())])
pipelines['ScalerLASSO'] = Pipeline([('Scaler',StandardScaler()),('LASSO',Lasso())])
pipelines['ScalerEN'] = Pipeline([('Scaler',StandardScaler()),('EN',ElasticNet())])
pipelines['ScalerKNN'] = Pipeline([('Scaler',StandardScaler()),('KNN',KNeighborsRegressor())])

pipelines['ScalerCART'] = Pipeline([('Scaler',StandardScaler()),('CART',DecisionTreeRegressor())])
pipelines['ScalerSVM'] = Pipeline([('Scaler',StandardScaler()),('SVM',SVR())])


X_train,X_validation,Y_train,Y_validation = train_test_split(X,Y,test_size = validation_size,random_state=seed)

results = []


for key in pipelines:

  kflod = KFold(n_splits=num_folds,random_state=seed,shuffle=True)
  cv_result = cross_val_score(pipelines[key], X_train, Y_train, cv=kflod,scoring= scoring)
  results.append(cv_result)
  print("%s: %.3f (%.3f)" % (key, cv_result.mean(), cv_result.std()))

Después de la ejecución, se encontró que el algoritmo K vecino más cercano tiene el MSE óptimo. Los resultados de la ejecución son los siguientes:

ScalerLR: -22.006 (12.189)
ScalerLASSO: -27.206 (12.124)
ScalerEN: -28.301 (13.609)
ScalerKNN: -21.457 (15.016)
ScalerCART: -27.813 (20.786)
ScalerSVM: -29.570 (18.053)

A continuación, echemos un vistazo a los resultados de toda la validación cruzada de 10 veces. El código se muestra a continuación:

#评估算法箱线图

fig = pyplot.figure()
fig.suptitle("Algorithm Comparison")
ax = fig.add_subplot(111)
pyplot.boxplot(results)
ax.set_xticklabels(models.keys())
pyplot.show()

Los resultados de la ejecución y el diagrama de caja generado se muestran en la figura. Se puede ver que el algoritmo vecino más cercano K tiene el MSE óptimo y la distribución de datos más compacta.

Insertar descripción de la imagen aquí
En la actualidad, el algoritmo K vecino más cercano tiene buenos resultados para conjuntos de datos que han sido transformados, pero ¿se pueden optimizar aún más los resultados?

El número de parámetros predeterminado de vecinos (n_vecinos) del algoritmo K vecino más cercano es 5. Los parámetros se optimizan a continuación utilizando el algoritmo de búsqueda de cuadrícula. El código se muestra a continuación:


#调参改善算法-knn
scaler = StandardScaler().fit(X_train) # fit生成规则
#scaler = StandardScaler.fit(X_train)
rescaledX = scaler.transform(X_train)
param_grid ={
    
    'n_neighbors':[1,3,5,7,9,11,13,15,19,21]}
model = KNeighborsRegressor()

kflod = KFold(n_splits=num_folds,random_state=seed,shuffle=True)
grid = GridSearchCV(estimator =model,param_grid=param_grid,scoring=scoring,cv = kflod)

grid_result = grid.fit(X=rescaledX,y = Y_train)

print('最优:%s 使用%s'%(grid_result.best_score_,grid_result.best_params_))

cv_results = zip(grid_result.cv_results_['mean_test_score'],
                 grid_result.cv_results_['params'])


for mean,param in cv_results:
  print(mean,param)

Resultado óptimo: el número de parámetros predeterminado de vecinos (n_vecinos) del algoritmo K vecino más cercano es 1. Los resultados de la ejecución
son los siguientes:

最优:-19.497828658536584 使用{
    
    'n_neighbors': 1}
-19.497828658536584 {
    
    'n_neighbors': 1}
-19.97798367208672 {
    
    'n_neighbors': 3}
-21.270966658536583 {
    
    'n_neighbors': 5}
-21.577291737182684 {
    
    'n_neighbors': 7}
-21.00107515055706 {
    
    'n_neighbors': 9}
-21.490306228582945 {
    
    'n_neighbors': 11}
-21.26853270313177 {
    
    'n_neighbors': 13}
-21.96809222222222 {
    
    'n_neighbors': 15}
-23.506900689142622 {
    
    'n_neighbors': 19}
-24.240302870416464 {
    
    'n_neighbors': 21}

Determinar el modelo final.

Hemos decidido utilizar el algoritmo de árbol aleatorio extremo (ET) para generar el modelo. A continuación, entrenaremos el algoritmo, generaremos el modelo y calcularemos la precisión del modelo. El código se muestra a continuación:

#训练模型

caler = StandardScaler().fit(X_train)
rescaledX = scaler.transform(X_train)
gbr = ExtraTreeRegressor()
gbr.fit(X=rescaledX,y=Y_train)
#评估算法模型

rescaledX_validation = scaler.transform(X_validation)
predictions = gbr.predict(rescaledX_validation)
print(mean_squared_error(Y_validation,predictions))

Los resultados de la ejecución son los siguientes:

14.392352941176469

Este ejemplo de proyecto completa un proyecto completo de aprendizaje automático desde la definición del problema hasta la generación final del modelo. A través de este proyecto, entendí la plantilla del proyecto de aprendizaje automático presentado en la sección anterior , así como todo el proceso de construcción de un modelo de aprendizaje automático.

Supongo que te gusta

Origin blog.csdn.net/hai411741962/article/details/132580367
Recomendado
Clasificación