"Derivación de fórmulas de aprendizaje automático e implementación de código" capítulo 14-CatBoost

Notas de estudio "Derivación de fórmulas de aprendizaje automático e implementación de código", registre su propio proceso de aprendizaje, compre el libro del autor para obtener contenido detallado.

impulso gato

CatBoost es un marco informático GBDT de código abierto creado por el gigante ruso del motor de búsqueda Yandex en 2017. Se llama CatBoost( Categorical Boosting) porque puede procesar de manera eficiente las características de categoría en los datos.

1 Método de procesamiento de características de categoría en aprendizaje automático

CatBoost mejora los métodos estadísticos de variables objetivo convencionales al agregarles anteriores. Además, CatBoost también considera el uso de diferentes combinaciones de características de categoría para aumentar la dimensión de características del conjunto de datos.

Para características de categoría con una gran cantidad de valores de características, un método de compromiso es reclasificar la cantidad de categorías para reducir el número a un número más pequeño y luego realizar una codificación one-hot. Otro método comúnmente usado es 目标变量统计( target statistics, TS), TS calcula el valor esperado de cada categoría para la variable objetivo y convierte las características de la categoría en nuevas características numéricas. CatBoost mejora el método TS convencional.

2 Base teórica de CatBoost

CatBoost目标变量统计Las características teóricas propias del marco del algoritmo, incluidas , 特征组合y , se utilizan para tratar con variables categóricas 排序提升算法.

2.1 Estadísticas de variables objetivo

CatBoostLa intención original del diseño del algoritmo es manejar mejor las características de GBDT categorical features. Cuando se trata de funciones categóricas en funciones GBDT, la forma más fácil es reemplazarlas con el valor promedio de la etiqueta correspondiente a la función categórica. En un árbol de decisión, el promedio de etiquetas se utilizará como criterio para la división de nodos. Este método se denomina Greedy Target-based Statisticsabreviado Greedy TSy se expresa mediante la fórmula:
x ^ ki = ∑ j = 1 n [ xj , k = xi , k ] Y i ∑ j = 1 n [ xj , k = xi , k ] \ hat{x} _{k}^{i} =\frac{\sum_{j=1}^{n}\left [ x_{j,k} =x_{i,k} \right ]Y_{i} }{\sum_ {j=1}^{n} \izquierda [x_{j,k} =x_{i,k} \derecha]}X^kyo=j = 1n[ Xj , k=Xyo , k]j = 1n[ Xj , k=Xyo , k]Yyo
Este método tiene un defecto evidente, es decir, normalmente las características contienen más información que las etiquetas, si se utiliza el valor promedio de las etiquetas para representar las características, aparecerá cuando la estructura de datos y la distribución del conjunto de datos de entrenamiento y el el conjunto de datos de prueba es diferente Problema de cambio condicional.

Una forma estándar de mejorar Greedy TS es agregar un término de distribución previo, que puede reducir el impacto del ruido y los datos de categorías de baja frecuencia en la distribución de datos:
x ^ ki = ∑ j = 1 p − 1 [ x σ j , k = x σ pags , k ] Y σ j + α pags ∑ j = 1 pags − 1 [ x σ j , k = x σ pags , k ] + α \hat{x}_{k}^{i} =\ frac{ \sum_{j=1}^{p-1}\left [ x_{\sigma _{j,k} } =x_{\sigma _{p,k} } \right ]Y_{\sigma_{ j} } + \alpha p}{\sum_{j=1}^{p-1} \left [ x_{\sigma _{j,k} } =x_{\sigma _{p,k} } \right ]+ \alfa}X^kyo=j = 1pag 1[ Xpagj , k=Xpagpag , k]+aj = 1pag 1[ Xpagj , k=Xpagpag , k]Ypagj+una p
donde p es el término anterior agregado y α suele ser un coeficiente de ponderación mayor que 0. Agregar un término anterior es una práctica común, que puede reducir los datos ruidosos para funciones con una pequeña cantidad de categorías. Para problemas de regresión, en general, el elemento anterior puede tomar el valor medio de la etiqueta del conjunto de datos. Para la clasificación binaria, el término previo es la probabilidad previa de ejemplos positivos. Las permutaciones que utilizan múltiples conjuntos de datos también son eficientes; sin embargo, pueden provocar un sobreajuste si se calculan directamente.

CatBoost utiliza un método relativamente novedoso para calcular los valores de los nodos hoja. Este método ( oblivious trees, árbol simétrico) puede evitar el problema del sobreajuste en los cálculos directos en la disposición de múltiples conjuntos de datos.

2.2 Combinación de funciones

Vale la pena señalar que cualquier combinación de varias características categóricas puede considerarse como una característica nueva. Por ejemplo, en una aplicación de recomendación de música, tenemos dos características categóricas: ID de usuario y género musical. Si algunos usuarios prefieren la música rock, al convertir la ID de usuario y el género musical en funciones digitales, se perderá la información de acuerdo con lo anterior.

La combinación de estas dos funciones resuelve este problema y produce una nueva función poderosa. Sin embargo, la cantidad de combinaciones crece exponencialmente con la cantidad de características categóricas en el conjunto de datos, por lo que es imposible considerar todas las combinaciones en el algoritmo.

Al construir nuevos puntos de división para el árbol actual, CatBoost considera combinaciones en una estrategia codiciosa. Para la primera división del árbol, no se considera ninguna combinación. Para la siguiente división, CatBoost combina todas las características categóricas combinadas del árbol actual con todas las características categóricas del conjunto de datos y convierte dinámicamente las nuevas características categóricas combinadas en características numéricas.

2.3 Algoritmo de refuerzo de clasificación

Para aprender a predecir compensaciones, tengo dos preguntas:

  • ¿Qué es el sesgo de pronóstico?
  • ¿Cuál es la solución al problema del sesgo de predicción?

El cambio de predicción ( Prediction shift) es causado por un sesgo de gradiente. En cada iteración de GDBT, la función de pérdida utiliza el mismo conjunto de datos para obtener el gradiente del modelo actual y luego entrena para obtener el alumno base, pero esto conducirá a una desviación de la estimación del gradiente, lo que a su vez conduce al problema de sobre -ajuste del modelo.

Ordered boostingCatBoost reemplaza el método de estimación de gradiente en el algoritmo tradicional mediante el uso de refuerzo de clasificación ( ), lo que reduce la desviación de la estimación de gradiente y mejora la capacidad de generalización del modelo.

CatBoostSe utiliza un árbol simétrico como clasificador base, y la simetría significa que en la misma capa del árbol, los criterios de división son los mismos. Los árboles simétricos están equilibrados, no son propensos a sobreajustarse y pueden reducir en gran medida el tiempo de prueba.

3 Implementación del algoritmo CatBoost

Como algoritmo de impulso con la misma reputación que XGBoost y LightGBM, CatBoost tiene indicadores de rendimiento suficientemente buenos, especialmente para el procesamiento de características de categoría.

import pandas as pd
data = pd.read_csv('./adult.data', header=None)
data

inserte la descripción de la imagen aquí

data.columns = ['age', 'workclass', 'fnlwgt', 'education', 'education-num', 'marital-status', 'occupation', 'relationship', 'race',
                'sex', 'capital-gain', 'capital-loss', 'hours-per-week', 'native-country', 'income'] # 变量重命名
data['income']
0         <=50K
1         <=50K
2         <=50K
3         <=50K
4         <=50K
          ...  
32556     <=50K
32557      >50K
32558     <=50K
32559     <=50K
32560      >50K
Name: income, Length: 32561, dtype: object
data['income'] = data['income'].astype('category').cat.codes
data['income'].unique()
array([0, 1], dtype=int8)
from sklearn.model_selection import train_test_split
import catboost as cb
from sklearn.metrics import accuracy_score
X_train, X_test, y_train, y_test = train_test_split(data.drop(['income'], axis=1), data['income'], random_state=10, test_size=0.3)
clf = cb.CatBoostClassifier(eval_metric='AUC', depth=4, iterations=500, l2_leaf_reg=1, learning_rate=0.1)
cat_features_index = [1, 3, 5, 6, 7, 8, 9, 13] # 设置分类特征的索引,以便 CatBoost 能够正确地识别这些特征
clf.fit(X_train, y_train, cat_features=cat_features_index)
y_pred = clf.predict(X_test)
print(accuracy_score(y_pred, y_test))
0:	total: 274ms	remaining: 2m 16s
1:	total: 337ms	remaining: 1m 23s
2:	total: 384ms	remaining: 1m 3s
3:	total: 434ms	remaining: 53.8s
4:	total: 485ms	remaining: 48s
5:	total: 558ms	remaining: 45.9s
6:	total: 596ms	remaining: 41.9s
7:	total: 642ms	remaining: 39.5s
8:	total: 676ms	remaining: 36.9s
9:	total: 712ms	remaining: 34.9s
10:	total: 748ms	remaining: 33.3s
11:	total: 782ms	remaining: 31.8s
12:	total: 816ms	remaining: 30.6s
13:	total: 854ms	remaining: 29.6s
14:	total: 896ms	remaining: 29s
15:	total: 941ms	remaining: 28.4s
16:	total: 981ms	remaining: 27.9s
17:	total: 1.02s	remaining: 27.3s
18:	total: 1.06s	remaining: 26.8s
19:	total: 1.1s	remaining: 26.4s
20:	total: 1.14s	remaining: 26s
21:	total: 1.18s	remaining: 25.6s
22:	total: 1.22s	remaining: 25.2s
23:	total: 1.25s	remaining: 24.8s
24:	total: 1.28s	remaining: 24.4s
...
497:	total: 18s	remaining: 72.4ms
498:	total: 18.1s	remaining: 36.2ms
499:	total: 18.1s	remaining: 0us
0.8721465861398301

Dirección de Notebook_Github

おすすめ

転載: blog.csdn.net/cjw838982809/article/details/131340125