Use WeightedRandomSampler to solve the problem of unbalanced data samples

Problem Description

The dataset samples are not balanced.

For example, in a binary classification task, samples labeled 0 might account for 90% of the data while samples labeled 1 account for only 10%. Training the model on all of the original data as-is is likely to bias the model toward the majority class, and may also make training very slow.

Balance data using WeightedRandomSampler

PyTorch official documentation: torch.utils.data.WeightedRandomSampler

According to the official documentation, WeightedRandomSampler samples elements from [0, ..., len(weights) - 1] with the given probabilities (weights).

The documentation gives the following code examples:

>>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
[4, 4, 1, 4, 5]
>>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
[0, 1, 4, 3, 2]

Using this example, the parameters are:

  • weights: the weight sequence used for sampling; the sequence does not need to sum to 1.
    For example, [0.1, 0.9, 0.4, 0.7, 3.0, 0.6] in the first example means the weight of sample 0 is 0.1, the weight of sample 1 is 0.9, ..., and the weight of sample 4 is 3.0. The weights can be read as relative (not normalized) probabilities: sample 4 is the most likely to be drawn, followed by samples 1, 3, 5, 2, and 0. A short sketch illustrating this empirically follows this list.
  • num_samples: the number of samples to draw.
    In the example, num_samples is 5, so the sampler yields a sequence of 5 indices.
  • replacement: whether sampling is done with replacement. If True, the same index may be drawn more than once; if False, each index appears at most once (so num_samples must not exceed the length of weights).
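To make the relationship between weights and sampling frequency concrete, here is a minimal sketch (an addition, not from the official documentation) that draws a large number of indices with the weights from the first example and counts how often each index is returned:

from collections import Counter

from torch.utils.data import WeightedRandomSampler

# Weights from the first documentation example; they do not need to sum to 1.
weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]

# Draw 10,000 indices with replacement and count how often each index appears.
sampler = WeightedRandomSampler(weights, num_samples=10_000, replacement=True)
counts = Counter(list(sampler))

for index in sorted(counts):
    print(f"index {index}: weight {weights[index]}, drawn {counts[index]} times")
# Index 4 (weight 3.0) should appear most often and index 0 (weight 0.1) least often.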

Application

The following code uses WeightedRandomSampler to rebalance an unbalanced dataset.

from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class MyDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return {
            "text": self.data[index]["text"],
            "label": self.data[index]["label"]
        }

    def __len__(self):
        return len(self.data)


if __name__ == '__main__':
    data = [
        {"text": "a", "label": 0}, {"text": "b", "label": 0}, {"text": "c", "label": 1}, {"text": "d", "label": 0},
        {"text": "e", "label": 0}, {"text": "f", "label": 0}, {"text": "g", "label": 0}, {"text": "h", "label": 0},
        {"text": "i", "label": 0}, {"text": "j", "label": 0}, {"text": "k", "label": 0}, {"text": "l", "label": 1}
    ]
    dataset = MyDataset(data)

    label_list = []
    for per_data in dataset:
        label_list.append(per_data["label"])
    print(f"label_list = {label_list}")

    # Each sample's weight is the reciprocal of its class frequency,
    # so every class contributes the same total weight.
    weights = [1.0 / label_list.count(label) for label in label_list]
    print(f"weights = {weights}")

    # Draw len(dataset) indices per epoch, with replacement.
    sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
    # Note: a DataLoader cannot use shuffle=True together with a sampler.
    train_loader = DataLoader(dataset, sampler=sampler, batch_size=4, shuffle=False, num_workers=0)
    for data in train_loader:
        print(data)

output:

label_list = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1]
weights = [0.1, 0.1, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.5]
{'text': ['c', 'a', 'g', 'l'], 'label': tensor([1, 0, 0, 1])}
{'text': ['e', 'j', 'l', 'k'], 'label': tensor([0, 0, 1, 0])}
{'text': ['c', 'g', 'c', 'k'], 'label': tensor([1, 0, 1, 0])}

The dataset has 12 samples: 10 with label 0 and 2 with label 1, so it is clearly unbalanced.

First compute the weight sequence: each sample with label 0 gets weight 1 / 10 = 0.1, and each sample with label 1 gets weight 1 / 2 = 0.5. The weights of all label-0 samples sum to 1.0, as do the weights of all label-1 samples, so during sampling the two labels are drawn with equal probability. The resulting dataloader is therefore roughly balanced; because the sampling is random, the counts of the two labels in any given epoch may still differ by one or two, but the data is far more balanced than the original dataset.
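The same reciprocal-of-class-frequency idea extends directly to more than two classes. Below is a minimal sketch (an addition, not part of the original post) that computes per-sample weights from a label list with collections.Counter and checks the label mix of one resampled epoch:

from collections import Counter

from torch.utils.data import WeightedRandomSampler

labels = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1]

# Weight each sample by the reciprocal of its class count, so every class
# contributes the same total weight.
class_counts = Counter(labels)
weights = [1.0 / class_counts[label] for label in labels]

# Draw one "epoch" of indices and inspect the resulting label distribution.
sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
sampled_labels = [labels[i] for i in sampler]
print(Counter(sampled_labels))  # roughly half label 0 and half label 1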

Original post: https://blog.csdn.net/Friedrichor/article/details/129901346