WeightedRandomSampler detailed explanation and example analysis

Application scenarios:

We often run into the class-imbalance problem. The simplest example: cat-vs-dog classification where you have 10,000 dog images but only 1,000 cat images. Training directly on the whole dataset easily makes the network overfit to the 'dog' class and perform poorly at recognizing cats.

Solution: use PyTorch's WeightedRandomSampler. In each epoch, samples are drawn according to per-sample weights derived from the size of each class.

For example, suppose each epoch trains the network on 1,600 images. Dog samples have a low probability of being drawn but are numerous; cat samples have a high probability of being drawn but are few. As a result, the two classes appear in roughly equal numbers among the 1,600 training images. Since a fresh sample is drawn every epoch, with enough epochs essentially every image gets used for training.
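As a quick sanity check, here is a minimal sketch (my own toy code using the numbers above, not from the original post) that draws 1,600 indices from an imbalanced 11,000-sample index space and counts how many fall in each class:

from collections import Counter
from torch.utils.data.sampler import WeightedRandomSampler

num_dog, num_cat = 10000, 1000
# Per-sample weights: each class's weights sum to 1, so the two classes are drawn about equally often
weights = [1 / num_dog] * num_dog + [1 / num_cat] * num_cat

sampler = WeightedRandomSampler(weights, num_samples=1600, replacement=False)
labels  = ["dog" if idx < num_dog else "cat" for idx in sampler]  # iterating the sampler yields dataset indices
print(Counter(labels))  # roughly balanced between the two classes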


Official usage introduction

Let's first look at the API from the official docs:

torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True, generator=None)
weights : a list; each number is the weight of the sample at that index, measuring how likely that sample is to be drawn (the weights do not need to sum to 1)
num_samples : an integer, the number of samples you want to draw; e.g. in the cat/dog example above, 1,600 images are sampled from the whole dataset each epoch
replacement : a bool, whether to sample with replacement. If True, the same image may be drawn more than once.
generator : an optional torch.Generator used for the draws; seed it if you want reproducible sampling, otherwise leave it as the default. A minimal usage sketch follows this list.
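Here is a minimal, self-contained usage sketch (toy weights of my own choosing, not from the original post), including a seeded generator for reproducible draws:

import torch
from torch.utils.data.sampler import WeightedRandomSampler

g = torch.Generator()
g.manual_seed(0)  # fixing the seed makes the draws reproducible

sampler = WeightedRandomSampler(
    weights=[0.1, 0.3, 0.6],  # relative weights; they need not sum to 1
    num_samples=5,
    replacement=True,         # drawing 5 from only 3 samples requires replacement
    generator=g,
)
print(list(sampler))  # e.g. [2, 1, 2, 2, 0] -- indices into the dataset; the heavy index 2 dominates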


Application examples:

Cat-vs-dog classification with 10,000 dog images and 1,000 cat images. Goal: train on only 1,600 images per epoch, sampled without replacement.

Key code:

from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler

# Custom dataset: the first 10,000 items are dogs, followed by 1,000 cats, 11,000 in total.
train_dataset   = DataGenerator(train_imgs, input_shape=input_shape, train=True)
# Number of samples in each class
num_dog         = 10000
num_cat         = 1000
# Weight list
train_weights   = []
train_weights.extend([1 / num_dog] * num_dog)  # 10,000 entries of 0.0001: weights for the first 10,000 samples; dogs are many, so each weight is small
train_weights.extend([1 / num_cat] * num_cat)  # 1,000 entries of 0.001: weights for the last 1,000 samples; cats are few, so each weight is large
print(train_weights)  # 11,000 numbers in total: 10,000 of 0.0001 followed by 1,000 of 0.001
# Create the WeightedRandomSampler; 1600 is the number of samples to draw
train_sampler   = WeightedRandomSampler(train_weights, 1600, replacement=False)
# Pass the sampler to the DataLoader; shuffle is no longer needed (sampler and shuffle are mutually exclusive)
gen             = DataLoader(train_dataset, batch_size=128, sampler=train_sampler, num_workers=8)

Explanation: the weight list holds one sampling weight per sample, so 11,000 samples get 11,000 weights. In this example the dog weights sum to 1 and the cat weights also sum to 1; because the two class totals are equal, the two classes appear in roughly equal numbers among the 1,600 samples drawn each epoch. The DataLoader first lets the sampler draw 1,600 images from the dataset, then splits them into batches for training.
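If you have more than two classes, or don't want to hard-code the counts, the same pattern generalizes. A hedged sketch (my own helper, not part of the original post) that builds the weight list from any list of integer class labels:

from collections import Counter
from torch.utils.data.sampler import WeightedRandomSampler

def make_sample_weights(labels):
    # Give every sample the weight 1 / (size of its class)
    class_counts = Counter(labels)
    return [1.0 / class_counts[y] for y in labels]

labels  = [0] * 10000 + [1] * 1000  # 0 = dog, 1 = cat, the same split as above
sampler = WeightedRandomSampler(make_sample_weights(labels), 1600, replacement=False)
# then pass it to the DataLoader exactly as in the key code above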

I was worried this might be hard to follow, so please forgive the long-windedness. If you have any questions, feel free to raise them in the comments.

Origin blog.csdn.net/ittongyuan/article/details/131088803