A detailed explanation of building a custom DataLoader in PyTorch

Introduction

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')
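With the default arguments, DataLoader simply draws indices one at a time, groups them into batches, and stacks the samples with a default collate function. A minimal usage sketch (TensorDataset here is only a stand-in dataset for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in dataset: 20 (feature, label) pairs
features = torch.arange(20, dtype=torch.float32).unsqueeze(1)
labels = torch.arange(20)
dataset = TensorDataset(features, labels)

# shuffle=True lets DataLoader pick a RandomSampler automatically;
# batch_size=4 makes it group the sampled indices into lists of 4.
loader = DataLoader(dataset, batch_size=4, shuffle=True)
for x, y in loader:
    print(x.shape, y)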

Details: DataLoader

The goal is to implement each of these modules yourself, following how DataLoader uses them.

Code

MyDataset: based on torch's Dataset, implements the loading of your own data, for example images and their labels, through __getitem__.
SingleSampler: based on torch's Sampler, produces the sequence of sample indices. Note that when a sampler is passed to DataLoader, shuffle must not be set, because the explicit sampler overrides the sampler that DataLoader would otherwise choose automatically from the shuffle flag.
MyBatchSampler: based on torch's BatchSampler, groups the indices coming from SingleSampler into lists of batch_size (for example, batch_size=4 takes 4 indices from the data as one group); MyDataset.__getitem__ is then called for each index in the group. Implementing this yourself is more flexible, but note that passing batch_sampler to DataLoader conflicts with batch_size, shuffle, sampler and drop_last, and it overrides batch_size.
collate_fn: packs the batch_size samples returned by __getitem__ into batch tensors, so that images and labels stay in correspondence while iterating over the loader.
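Putting these pieces together: on each iteration, DataLoader asks the batch sampler for a group of indices, fetches each sample through Dataset.__getitem__, and hands the resulting list to collate_fn. A simplified single-process sketch of that loop, for intuition only (this is not the actual library code):

from typing import Callable, Iterable, List

def simple_loader(dataset, batch_sampler: Iterable[List[int]], collate_fn: Callable):
    # Simplified view of one epoch of a single-process DataLoader
    for indices in batch_sampler:                 # e.g. [3, 17, 5, 9]
        samples = [dataset[i] for i in indices]   # calls Dataset.__getitem__ per index
        yield collate_fn(samples)                 # e.g. torch.stack(samples, 0)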

Using sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler


class MyDataset(Dataset):
    def __init__(self) -> None:
        self.data = torch.arange(20)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    @staticmethod
    def collate_fn(batch):
        return torch.stack(batch, 0)

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler[int], batch_size: int) -> None:
        self._sampler = sampler
        self._batch_size = batch_size
    
    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Yield the final, possibly smaller batch; skip it if it is empty
        if batch:
            yield batch
    
    def __len__(self):
        # Ceiling division: the trailing incomplete batch is also yielded
        return (len(self._sampler) + self._batch_size - 1) // self._batch_size

class SingleSampler(Sampler):
    def __init__(self, data_source) -> None:
        self._data = data_source
        self.num_samples = len(self._data)
        
    def __iter__(self):
        # Sequential sampling:
        # indices = range(len(self._data))
        # Random sampling (a new permutation each epoch):
        indices = torch.randperm(self.num_samples).tolist()
        return iter(indices)
    
    def __len__(self):
        return self.num_samples
        

train_set = MyDataset()
single_sampler = SingleSampler(train_set)
train_loader = DataLoader(train_set, batch_size=4, sampler=single_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:
    print(data)
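When only sampler and batch_size are given, DataLoader builds the batching step itself by wrapping the sampler in torch.utils.data.BatchSampler. So the loader above behaves much like this explicit version (drop_last=False matches the DataLoader default):

from torch.utils.data import BatchSampler, DataLoader

# Roughly what DataLoader sets up internally for sampler=single_sampler, batch_size=4
wrapped = BatchSampler(single_sampler, batch_size=4, drop_last=False)
equivalent_loader = DataLoader(train_set, batch_sampler=wrapped, pin_memory=True,
                               collate_fn=MyDataset.collate_fn)
for data in equivalent_loader:
    print(data)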

Using batch_sampler

from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler

class MyDataset(Dataset):
    def __init__(self) -> None:
        self.data = torch.arange(20)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        return self.data[index]
    
    @staticmethod
    def collate_fn(batch):
        return torch.stack(batch, 0)

class MyBatchSampler(BatchSampler):
    def __init__(self, sampler: Sampler[int], batch_size: int) -> None:
        self._sampler = sampler
        self._batch_size = batch_size
    
    def __iter__(self) -> Iterator[List[int]]:
        batch = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Yield the final, possibly smaller batch; skip it if it is empty
        if batch:
            yield batch
    
    def __len__(self):
        # Ceiling division: the trailing incomplete batch is also yielded
        return (len(self._sampler) + self._batch_size - 1) // self._batch_size

class SingleSampler(Sampler):
    def __init__(self, data_source) -> None:
        self._data = data_source
        self.num_samples = len(self._data)
        
    def __iter__(self):
        # Sequential sampling:
        # indices = range(len(self._data))
        # Random sampling (a new permutation each epoch):
        indices = torch.randperm(self.num_samples).tolist()
        return iter(indices)
    
    def __len__(self):
        return self.num_samples
        

train_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)
train_loader = DataLoader(train_set, batch_sampler=batch_sampler, pin_memory=True, collate_fn=MyDataset.collate_fn)
for data in train_loader:
    print(data)
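The toy dataset above returns bare integers. With real data, __getitem__ usually returns an (image, label) pair and collate_fn keeps the two aligned. Below is a sketch of that pattern, assuming a hypothetical folder containing image files plus a labels.txt index; the directory layout, file names, and the placeholder decode step are all assumptions, not part of the original code:

import os
from typing import List, Tuple

import torch
from torch.utils.data import Dataset

class ImageFolderDataset(Dataset):
    # Hypothetical dataset: assumes root/labels.txt holds "<filename> <label>" per line
    def __init__(self, root: str) -> None:
        self.items: List[Tuple[str, int]] = []
        with open(os.path.join(root, "labels.txt")) as f:
            for line in f:
                if not line.strip():
                    continue
                name, label = line.split()
                self.items.append((os.path.join(root, name), int(label)))

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        path, label = self.items[index]
        # Placeholder decode: real code would load the file, e.g. with torchvision.io.read_image(path)
        image = torch.zeros(3, 224, 224)
        return image, label

    @staticmethod
    def collate_fn(batch):
        # Keep images and labels aligned: stack images, collect labels into one tensor
        images = torch.stack([img for img, _ in batch], 0)
        labels = torch.tensor([lbl for _, lbl in batch])
        return images, labels

Such a dataset drops straight into the SingleSampler / MyBatchSampler / DataLoader pipeline shown above.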

Reference

Sampler: https://blog.csdn.net/lidc1004/article/details/115005612

Original article: https://blog.csdn.net/frighting_ing/article/details/132402386