【深度好文】Pytorch不均衡数据集采样器

1. 数据不均衡问题引入

在很多机器学习的应用中,我们经常会遇到一些数据集中某类数据样本数量相比其他类别数量多很多。我们以罕见病诊断为例,一般来说正常样本数量比生病样本数量多。在样本不均衡的情况下,我们需要确保经过训练的模型不会偏向于拥有更多数据量的类别。
举个例子,假设上述罕见病诊断的数据集里一共含有25张图像,其中5个为患病样本图像,剩下20个为正常样本图像。假设我们的模型预测所有图像均为正常,此时我们可以计算我们模型的评测指标如下:
请添加图片描述相应的准确率和召回率如下:
请添加图片描述如上所示,上述模型的准确率达到了80%,同时这样一个模型的F1-score达到了0.88。因此,该模型有很高的趋势来倾向于预测正常类别。

2. 数据重采样

为了解决上述数据不均衡的问题,一个被广泛采用的技术叫做数据重采样。它包含从多数类别的样本中删除个别样本(下采样)以及往少数类别的样本中添加更多样本(过采样)。尽管上述两种方法可以保持样本类别数量的均衡,他们也存在各自的缺点,毕竟天下没有免费的午餐。实现过采样一种简单的方式是重复复制数量较少的类别的样本,这有可能会导致模型过拟合; 同时实现下采样最简单的方式是删除数量较多的某类样本,这可能会导致训练过程信息丢失。下采样和过采样的示意图如下所示:
请添加图片描述

图1: 左-下采样示意图 右-过采样示意图


3. 不均衡数据集采样器

我们实现了一个易于使用的PyTorch采样器ImbalancedDatasetSampler,如下所示
请添加图片描述

图2: 不均衡采样示意图



它具有以下特点:

  • 能够从不平衡的数据集采样后重新平衡类间分布
  • 能够自动估计采样时的权值
  • 避免创建新的平衡数据集
  • 当它与数据增强技术一起使用时,可以减轻过拟合

4. 使用方法

首先,通过pip方式进行安装,安装命令如下:

pip install https://github.com/ufoym/imbalanced-dataset-sampler/archive/master.zip

在代码中使用上述sampler,仅需要将 ImbalancedDatasetSampler 作为创建DataLoader时的相关参数即可。举例如下:

from torchsampler import ImbalancedDatasetSampler
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=ImbalancedDatasetSampler(train_dataset),
    batch_size=args.batch_size,
    **kwargs
)

5. 性能验证

我们以不均衡手写字符分类数据集为例,来说明上述采样器的具体性能。
我们构造的不均衡手写字符分类数据集的分布如下:
在这里插入图片描述

图3:不均衡手写字符识别数据集样本分布

我们使用普通的sampler来训练模型,得到模型评估性能如下:
在这里插入图片描述

图4:左- acc随epoch变化曲线 右-混淆矩阵

如果使用上述 ImbalancedDatasetSampler来训练模型,得到的模型评估性能如下:
ImbalancedDatasetSampler
图5:左- acc随epoch变化曲线 右-混淆矩阵

6. 结论

注意上述使用 ImbalancedDatasetSampler 后,对于样本数量较少的类别比如2,6,9类在acc上均有明显的提升,同时样本数量较多的其他类别,acc基本保持不变。
Wow,推荐大家在日常工作生活中积极使用。

7. 参考

参考链接:
链接一

猜你喜欢

转载自blog.csdn.net/sgzqc/article/details/118965083#comments_27347846