Introduction
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=None, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=None, persistent_workers=False, pin_memory_device='')
Details: DataLoader
Implement each module yourself based on DataLoader
Code
MyDataset
Data
Based on the Dataset implementation in torch, this class loads a personal (custom) data set, for example loading images and their labels; each sample is fetched through __getitem__.
SingleSampler
Based on the Sampler implementation in torch, this class decides which sample indexes are loaded and in what order.
For example, with Batch_Size=4, 4 indexes are selected from all the data as one group; the images are then loaded by index through MyDataset.
Based on the BatchSampler implementation in torch, this class implements our own grouping of batch_size samples. It is built on top of an existing Sampler (here SingleSampler), which makes it more flexible. When a MyBatchSampler is passed to DataLoader as batch_sampler, it takes over batching from DataLoader's own batch_size, shuffle, sampler and drop_last parameters, which must then be left at their defaults. Note: passing a custom sampler conflicts with shuffle; only when no sampler is implemented does DataLoader automatically choose the sampler type from shuffle. collate_fn packages the batch_size samples returned by __getitem__ into one batch, so that the images and labels of a batch correspond to each other during traversal.
MyBatchSampler
BatchSampler
SingleSampler
MyBatchSampler
DataLoader
batch_size
Sampler
shuffle
shuffle
sampler
sampler
collate_fn
sampler
from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler
class MyDataset(Dataset):
    """Toy dataset whose samples are the integers 0..19 stored in a tensor."""

    def __init__(self) -> None:
        # One 1-D tensor holds every sample; element i is sample i.
        self.data = torch.arange(0, 20)

    def __len__(self):
        """Return the number of samples held in ``self.data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        sample = self.data[index]
        return sample

    @staticmethod
    def collate_fn(batch):
        """Stack a list of sample tensors along a new leading dimension."""
        stacked = torch.stack(batch, dim=0)
        return stacked
class MyBatchSampler(BatchSampler):
    """Group the indices produced by a base sampler into mini-batches.

    Args:
        sampler: Base sampler that yields one dataset index at a time and
            supports ``len()``.
        batch_size: Number of indices per batch.
        drop_last: If ``True``, discard the final batch when it holds fewer
            than ``batch_size`` indices. Defaults to ``False``, which keeps
            the shorter final batch.
    """

    def __init__(
        self, sampler: Sampler[int], batch_size: int, drop_last: bool = False
    ) -> None:
        # Initialise the BatchSampler base so its public
        # sampler / batch_size / drop_last attributes are populated.
        super().__init__(sampler, batch_size, drop_last)
        self._sampler = sampler
        self._batch_size = batch_size
        self._drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        """Yield lists of at most ``batch_size`` indices."""
        batch: List[int] = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Emit the leftover batch only when it actually holds indices: an
        # empty batch would make collate functions such as torch.stack fail.
        if batch and not self._drop_last:
            yield batch

    def __len__(self) -> int:
        """Return the number of batches that ``__iter__`` will yield."""
        full, remainder = divmod(len(self._sampler), self._batch_size)
        # The partial final batch counts unless it is being dropped.
        if remainder and not self._drop_last:
            return full + 1
        return full
class SingleSampler(Sampler):
    """Sampler that yields every index of ``data_source`` once, shuffled."""

    def __init__(self, data_source) -> None:
        self._data = data_source
        # Cache the size once; both __iter__ and __len__ rely on it.
        self.num_samples = len(data_source)

    def __iter__(self):
        """Return an iterator over a fresh random permutation of indices."""
        # Random sampling. (Sequential sampling would instead iterate over
        # range(len(self._data)).)
        shuffled = torch.randperm(self.num_samples)
        return iter(shuffled.tolist())

    def __len__(self):
        """Return the number of indices produced per pass."""
        return self.num_samples
# Example 1: let DataLoader do the batching, drawing indices from a custom
# per-sample sampler.
train_set = MyDataset()
single_sampler = SingleSampler(train_set)
# Constructed here but not passed to the loader below, which batches through
# its own ``batch_size`` argument.
batch_sampler = MyBatchSampler(single_sampler, 8)

train_loader = DataLoader(
    train_set,
    batch_size=4,
    sampler=single_sampler,
    pin_memory=True,
    collate_fn=MyDataset.collate_fn,
)

# Print every collated batch.
for data in train_loader:
    print(data)
batch_sampler
from typing import Iterator, List
import torch
from torch.utils.data import BatchSampler
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torch.utils.data import Sampler
class MyDataset(Dataset):
    """Toy dataset whose samples are the integers 0..19 stored in a tensor."""

    def __init__(self) -> None:
        # One 1-D tensor holds every sample; element i is sample i.
        self.data = torch.arange(0, 20)

    def __len__(self):
        """Return the number of samples held in ``self.data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample stored at position ``index``."""
        sample = self.data[index]
        return sample

    @staticmethod
    def collate_fn(batch):
        """Stack a list of sample tensors along a new leading dimension."""
        stacked = torch.stack(batch, dim=0)
        return stacked
class MyBatchSampler(BatchSampler):
    """Group the indices produced by a base sampler into mini-batches.

    Args:
        sampler: Base sampler that yields one dataset index at a time and
            supports ``len()``.
        batch_size: Number of indices per batch.
        drop_last: If ``True``, discard the final batch when it holds fewer
            than ``batch_size`` indices. Defaults to ``False``, which keeps
            the shorter final batch.
    """

    def __init__(
        self, sampler: Sampler[int], batch_size: int, drop_last: bool = False
    ) -> None:
        # Initialise the BatchSampler base so its public
        # sampler / batch_size / drop_last attributes are populated.
        super().__init__(sampler, batch_size, drop_last)
        self._sampler = sampler
        self._batch_size = batch_size
        self._drop_last = drop_last

    def __iter__(self) -> Iterator[List[int]]:
        """Yield lists of at most ``batch_size`` indices."""
        batch: List[int] = []
        for idx in self._sampler:
            batch.append(idx)
            if len(batch) == self._batch_size:
                yield batch
                batch = []
        # Emit the leftover batch only when it actually holds indices: an
        # empty batch would make collate functions such as torch.stack fail.
        if batch and not self._drop_last:
            yield batch

    def __len__(self) -> int:
        """Return the number of batches that ``__iter__`` will yield."""
        full, remainder = divmod(len(self._sampler), self._batch_size)
        # The partial final batch counts unless it is being dropped.
        if remainder and not self._drop_last:
            return full + 1
        return full
class SingleSampler(Sampler):
    """Sampler that yields every index of ``data_source`` once, shuffled."""

    def __init__(self, data_source) -> None:
        self._data = data_source
        # Cache the size once; both __iter__ and __len__ rely on it.
        self.num_samples = len(data_source)

    def __iter__(self):
        """Return an iterator over a fresh random permutation of indices."""
        # Random sampling. (Sequential sampling would instead iterate over
        # range(len(self._data)).)
        shuffled = torch.randperm(self.num_samples)
        return iter(shuffled.tolist())

    def __len__(self):
        """Return the number of indices produced per pass."""
        return self.num_samples
# Example 2: hand batching over to a custom batch sampler, which wraps the
# per-sample sampler and yields lists of indices.
train_set = MyDataset()
single_sampler = SingleSampler(train_set)
batch_sampler = MyBatchSampler(single_sampler, 8)

train_loader = DataLoader(
    train_set,
    batch_sampler=batch_sampler,
    pin_memory=True,
    collate_fn=MyDataset.collate_fn,
)

# Print every collated batch.
for data in train_loader:
    print(data)
reference
Sampler:https://blog.csdn.net/lidc1004/article/details/115005612