Python imblearn 解决 类别不平衡问题

0. 问题背景及解决方法

类别不平衡问题
       类别不平衡问题,顾名思义,即数据集中存在某一类样本,其数量远多于或远少于其他类样本,从而导致一些机器学习模型失效的问题。例如逻辑回归即不适合处理类别不平衡问题,例如逻辑回归在欺诈检测问题中,因为绝大多数样本都为正常样本,欺诈样本很少,逻辑回归算法会倾向于把大多数样本判定为正常样本,这样能达到很高的准确率,但是达不到很高的召回率。

       类别不平衡问题在很多场景中存在,例如欺诈检测,风控识别,在这些样本中,黑样本(一般为存在问题的样本)的数量一般远少于白样本(正常样本)。

       上采样(过采样)和下采样(负采样)策略是解决类别不平衡问题的基本方法之一。上采样即增加少数类样本的数量,下采样即减少多数类样本以获取相对平衡的数据集。

       最简单的上采样方法可以直接将少数类样本复制几份后添加到样本集中,最简单的下采样则可以直接只取一定百分比的多数类样本作为训练集。

       SMOTE算法是用的比较多的一种上采样算法,SMOTE算法的原理并不是太复杂,用python从头实现也只有几十行代码,但是python的imblearn包提供了更方便的接口,在需要快速实现代码的时候可直接调用imblearn。
————————————————
版权声明:本文为CSDN博主「nlpuser」的原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/nlpuser/article/details/81265614

1. imblearn安装

前置条件

  • numpy (>=1.11)
  • scipy (>=0.17)
  • scikit-learn (>=0.21)
  • keras 2 (optional)
  • tensorflow (optional)

安装命令

目前,PyPi的库中已经提供了imbalanced-learn,因此,可以使用pip命令实现安装:

pip install imbalanced-learn

* 附上官网链接

https://imbalanced-learn.readthedocs.io/en/stable/generated/imblearn.over_sampling.SMOTE.html#imblearn.over_sampling.SMOTE

2. 测试样例

安装完成后,使用sklearn的make_classification造一些类别不均衡的数据,其中y0:y1=1:9。

为方便展示,使用2维数据。

from collections import Counter
from sklearn.datasets import make_classification
from imblearn.over_sampling import SMOTE # doctest: +NORMALIZE_WHITESPACE

# n_informative:有用特征数量,n_redundant冗余特征数量
X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9], 
                           n_informative=2, n_redundant=0, flip_y=0, 
                           n_features=2, n_clusters_per_class=2, 
                           n_samples=200, random_state=10)

print('Original dataset shape %s' % Counter(y))

sm = SMOTE(random_state=42)
X_res, y_res = sm.fit_resample(X, y)
print('Resampled dataset shape %s' % Counter(y_res))


# plot
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
fig,ax = plt.subplots()
cm = plt.cm.RdBu
cm_bright = ListedColormap(['#FF0000', '#0000FF'])
                            
ax.set_title("test for SMOTE")

# Plot the Resampled dataset points
ax.scatter(X_res[:, 0], X_res[:, 1], c=y_res, cmap=cm_bright, alpha=0.2,
           edgecolors='k')
# Plot the Original dataset points
ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cm_bright,
           edgecolors='k')

输出:

Original dataset shape Counter({1: 180, 0: 20})
Resampled dataset shape Counter({0: 180, 1: 180})

浅色点代表生成的数据,深色点代表原始数据。 

发布了28 篇原创文章 · 获赞 5 · 访问量 4057

猜你喜欢

转载自blog.csdn.net/authorized_keys/article/details/103026814