机器学习平台系列——XGB feature_names mismatch 问题解决方案

最近开发公司的机器学习平台的XGBoost控件。结果报了一个bug,说“feature_names mismatch”。

现在我们来复现这个bug:

import xgboost as xgb
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

X, y = make_classification(n_features=4)
X_train, X_test, y_train, y_test = train_test_split(X, y)
features = ['a1', 'a2', 'a3', 'a4']
df = pd.DataFrame(data=X_train, columns=features)
df['y'] = y_train

X_train = df[['a1', 'a2', 'a3', 'a4']]
y_train = df['y']
model = xgb.XGBClassifier()
model.fit(X_train, y_train)

以上代码随机生成了一个分类数据集,它的特征的名字是a1, a2, a3, a4。我们用pandas的数据集,而不是numpy的数列,传给XGB的fit函数来训练模型。这里还没有出现bug。

然而在预测的时候,出现了bug。

model.predict(X_test)

bug的描述如下:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-23-decd67bbf386> in <module>
----> 1 model.predict(X_test)

C:\ProgramData\Anaconda3\lib\site-packages\xgboost\sklearn.py in predict(self, data, output_margin, ntree_limit, validate_features, base_margin)
    882         if ntree_limit is None:
    883             ntree_limit = getattr(self, "best_ntree_limit", 0)
--> 884         class_probs = self.get_booster().predict(
    885             test_dmatrix,
    886             output_margin=output_margin,

C:\ProgramData\Anaconda3\lib\site-packages\xgboost\core.py in predict(self, data, output_margin, ntree_limit, pred_leaf, pred_contribs, approx_contribs, pred_interactions, validate_features, training)
   1569 
   1570         if validate_features:
-> 1571             self._validate_features(data)
   1572 
   1573         length = c_bst_ulong()

C:\ProgramData\Anaconda3\lib\site-packages\xgboost\core.py in _validate_features(self, data)
   2128                             ', '.join(str(s) for s in my_missing))
   2129 
-> 2130                 raise ValueError(msg.format(self.feature_names,
   2131                                             data.feature_names))
   2132 

ValueError: feature_names mismatch: ['a1', 'a2', 'a3', 'a4'] ['f0', 'f1', 'f2', 'f3']
expected a1, a3, a2, a4 in input data
training data did not have the following fields: f2, f1, f3, f0

遇到Bug,首先就是去stackoverflow上面找答案,很快找到了解决方案。可以在模型预测的时候加上validate_features=False

model.predict(X_test, validate_features=False)

这里问题来了。平时如果我们自己建模,到这里,问题就解决了。但是在机器学习平台中,预测是另外一个组件。如果我改了预测组件的代码,可能会影响到很多地方,可能会引起更多的bug。

于是,有的人可能会这样做:

if is_xgb:
    model.predict(X_test, validate_features=False)
else:
    model.predict(X_test)

我一开始学编程的时候,也是这样解决问题的。但其实还有更好的办法,一个可以不修改预测组件的办法。

我们看到,XGB和其他算法的【行为不一致】。其他算法,包括sklearn里面的主要算法,我们可以用pandas数据集训练模型,然后用numpy数列做预测。整个系统也是基于sklearn里面的模型的这种行为设计的。可是偏偏XGB的行为方式和大家不一样。在一个系统当中,应该要求每个模块都遵循一定的规则。这个就是为什么在java和其他语言中,我们会使用继承和接口。所以,我们这里要做的是,修改XGB的predict函数,在不给定validate_features=False的情况下,也实现一样的功能。所以,这里要怎么办呢?

有人会说,可以修改XGB的源代码。也可以,就是难度太大。有没有简单一点的方法?

(在我继续之前,我特地放了一张一休的图片,让大家闭上眼睛想一想。)

一休

我提醒你一下,我们在连接投影仪的时候,经常碰到电脑显卡的插口和投影仪的不一致。这个时候怎么办?是不是可以用一个名字叫adapter的中间设备来转换一下?

在这里插入图片描述

我的方法,就是用一个adapter类,套在原始的XGBClassifier的外面。用新类的predict调用老类的predict,在调用的时候,加上validate_features=False。具体如下:

class XGBClassifierAdapter():
    model = None
    
    def __init__(self, **params):
        self.model = xgb.XGBClassifier(**params)
        
    def fit(self, X, y):
        self.model.fit(X, y)
        
    def predict(self, X):
        return self.model.predict(X, validate_features=False)

这时,我们用这个新类去替换老类,就不会报错了。

X, y=make_classification(n_features=4)
X_train, X_test, y_train, y_test = train_test_split(X, y)
features=['a1', 'a2', 'a3', 'a4']
df=pd.DataFrame(data=X_train, columns=features)
df['y']=y_train

X_train = df[['a1', 'a2', 'a3', 'a4']]
y_train = df['y']
model = XGBClassifierAdapter()
model.fit(X_train, y_train)
model.predict(X_test)

这种方法,有个学名,叫adapter,是一种设计模式。它将一个类的接口转换成另外一个接口,使得原本由于接口不兼容而不能一起工作的那些类可以一起工作。

(这里转化前后的函数名字都叫predict,但是他们其实不一样。)

也就是说XGBClassifierAdepter要满足两个条件。第一它要改变XGBClassifier的接口,那么它就要依赖于这个类。另外,它又要和其他的sklearn的类接口要一致,或者说他们要具有共同的父类/接口,我们这里想象一个接口叫Classifier,那么XGBClassifierAdepter也要同时继承这个父类/接口。

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/juwikuang/article/details/108293148
今日推荐