Preface
Whether for model training or for testing, reading the data is the first step, because deep learning is driven by data. Once a program can read the data correctly, it can be combined with PyTorch's preprocessing tools, pre-trained models and optimizers to build a simple model for prediction. The core of data reading in PyTorch is the DataLoader class together with the Dataset class. The structure of this note mainly follows the Deep Eyes course, with some extensions that come mainly from my translation and understanding of the torch documentation.
Data splitting: see the previous note, "Deep Eyes PyTorch check-in (6): Splitting a dataset into training, validation and test sets".
Dataset class
Dataset is an abstract class that represents a dataset; a subclass of it defines where the data is read from and how it is read. Where from? From disk, which is done by passing a path argument to the Dataset. How? The reading logic is written by us, and it differs depending on how the dataset was split.
```python
from torch.utils.data import Dataset

class DataSet(Dataset):
    pass
```
We define our own class representing the dataset by inheriting from the abstract Dataset class, and then instantiate this subclass to represent the training, validation and test data. Every subclass of Dataset must override __getitem__(), and usually also overrides __len__() and __init__().
The basic framework is shown in the code:
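```python
from torch.utils.data import Dataset

class DataSet(Dataset):  # don't name the subclass Dataset, or it shadows the parent class
    def __init__(self):
        ...

    def __getitem__(self, index):
        return ...

    def __len__(self):
        return ...
```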
- Method implementation and instantiation
__init__(self): sets up the dataset's own attributes, such as the labels, the sample information, and whether to apply data augmentation.
__len__(self): returns the size of the dataset.
__getitem__(self, index): receives an index and returns the corresponding sample and label. It is the core of data reading; the index is generated by the sampler class used inside DataLoader().
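A minimal sketch of the three methods on a tiny in-memory dataset. To keep it free of file I/O and of PyTorch itself, ToyDataset is a hypothetical plain class; in real code it would subclass torch.utils.data.Dataset exactly as in the skeleton above.

```python
class ToyDataset:
    """Hypothetical map-style dataset: (sample, label) pairs held in memory."""

    def __init__(self, samples, labels):
        # __init__: store whatever attributes the dataset needs
        assert len(samples) == len(labels), "samples and labels must align"
        self.samples = list(samples)
        self.labels = list(labels)

    def __getitem__(self, index):
        # __getitem__: map one index to one (sample, label) pair
        return self.samples[index], self.labels[index]

    def __len__(self):
        # __len__: number of samples, which tells a sampler the valid index range
        return len(self.samples)


ds = ToyDataset(samples=[[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], labels=[0, 1, 0])
print(len(ds))  # 3
print(ds[1])    # ([0.3, 0.4], 1)
```

Indexing with ds[1] calls __getitem__(1), and len(ds) calls __len__(); these two calls are all a sampler and DataLoader need from a map-style dataset.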
Data reading for the first splitting method from the previous note, where each split folder contains one sub-folder per class, is implemented as follows:
```python
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

label_name = {'ants': 0, 'bees': 1}


class DataSet(Dataset):
    def __init__(self, data_path):  # besides these two parameters, others can also be added
        self.label_name = {'ants': 0, 'bees': 1}
        self.data_info = get_info(data_path)

    def __getitem__(self, index):
        label, img_path = self.data_info[index]
        pil_img = Image.open(img_path).convert('RGB')  # read the image
        re_img = transforms.Resize((32, 32))(pil_img)
        img = transforms.ToTensor()(re_img)  # PIL image -> tensor
        return img, label

    def __len__(self):
        return len(self.data_info)


def get_info(data_path):
    # Collect (label, path) pairs; each sub-folder of data_path is one class
    data_info = list()
    for root_dir, sub_dirs, _ in os.walk(data_path):
        for sub_dir in sub_dirs:
            file_names = os.listdir(os.path.join(root_dir, sub_dir))
            img_names = list(filter(lambda x: x.endswith('.jpg'), file_names))
            for i in range(len(img_names)):
                img_path = os.path.join(root_dir, sub_dir, img_names[i])
                img_label = label_name[sub_dir]
                data_info.append((img_label, img_path))
    return data_info


if __name__ == '__main__':
    train_set_path = os.path.join('data', 'train_set')
    val_set_path = os.path.join('data', 'val_set')
    test_set_path = os.path.join('data', 'test_set')
    train_set = DataSet(data_path=train_set_path)
    val_set = DataSet(data_path=val_set_path)
    test_set = DataSet(data_path=test_set_path)
```
Breakpoint debugging results:
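To inspect what self.data_info contains without a debugger, the sketch below builds a throwaway ants/bees folder in a temporary directory (the file names are made up and the files are empty) and runs a standalone copy of get_info() on it. PIL, torchvision and torch are not needed for this step.

```python
import os
import tempfile

label_name = {'ants': 0, 'bees': 1}


def get_info(data_path):
    # Copy of get_info() above: each sub-folder is a class, only .jpg files are kept
    data_info = list()
    for root_dir, sub_dirs, _ in os.walk(data_path):
        for sub_dir in sub_dirs:
            file_names = os.listdir(os.path.join(root_dir, sub_dir))
            img_names = list(filter(lambda x: x.endswith('.jpg'), file_names))
            for img_name in img_names:
                img_path = os.path.join(root_dir, sub_dir, img_name)
                data_info.append((label_name[sub_dir], img_path))
    return data_info


with tempfile.TemporaryDirectory() as tmp:
    # Hypothetical layout: tmp/ants/{a1.jpg, a2.jpg, notes.txt}, tmp/bees/b1.jpg
    layout = {'ants': ['a1.jpg', 'a2.jpg', 'notes.txt'], 'bees': ['b1.jpg']}
    for class_dir, names in layout.items():
        os.makedirs(os.path.join(tmp, class_dir))
        for name in names:
            open(os.path.join(tmp, class_dir, name), 'w').close()  # empty placeholder file
    info = get_info(tmp)
    # os.walk/os.listdir order is not guaranteed, so sort before printing
    summary = sorted((label, os.path.basename(path)) for label, path in info)

print(summary)  # [(0, 'a1.jpg'), (0, 'a2.jpg'), (1, 'b1.jpg')]
```

notes.txt is skipped by the .jpg filter, and every image gets the label of the folder it sits in.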
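Data reading for the second splitting method from the previous note, where each split is described by a .txt list file, is implemented below. The main difference is the function that builds data_info: get_info() is replaced by get_info_list().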
```python
import os
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset

label_name = {'ants': 0, 'bees': 1}


class DataSet(Dataset):
    def __init__(self, data_path):  # besides these two parameters, others can also be added
        self.label_name = {'ants': 0, 'bees': 1}
        self.data_info = get_info_list(data_path)

    def __getitem__(self, index):
        label, img_path = self.data_info[index]
        pil_img = Image.open(img_path).convert('RGB')  # read the image
        re_img = transforms.Resize((32, 32))(pil_img)
        img = transforms.ToTensor()(re_img)  # PIL image -> tensor
        return img, label

    def __len__(self):
        return len(self.data_info)


def get_info_list(list_path):
    # Each line of the list file is "<label> <image path>"
    data_info = list()
    with open(list_path, mode='r') as f:
        lines = f.readlines()
    for line in lines:
        line = line.strip()  # drop the trailing newline, otherwise it ends up in the path
        if not line:
            continue  # skip blank lines
        img_label, img_path = line.split(' ', 1)
        data_info.append((int(img_label), img_path))
    return data_info


if __name__ == '__main__':
    train_list_path = os.path.join('old_data', 'train_set.txt')
    val_list_path = os.path.join('old_data', 'val_set.txt')
    test_list_path = os.path.join('old_data', 'test_set.txt')
    train_set = DataSet(data_path=train_list_path)
    val_set = DataSet(data_path=val_list_path)
    test_set = DataSet(data_path=test_list_path)
```
Breakpoint debugging results:
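The list-file variant assumes one "label path" pair per line, such as 0 data/ants/1.jpg (the paths here are hypothetical). This small check shows why each line should be stripped before splitting. It uses io.StringIO in place of a real file and a hypothetical helper, parse_list_lines(), that applies the parsing rule a robust get_info_list() needs.

```python
import io


def parse_list_lines(lines):
    """Hypothetical helper: turn 'label path' lines into (label, path) tuples."""
    data_info = []
    for line in lines:
        line = line.strip()                       # drop the trailing '\n' and stray spaces
        if not line:
            continue                              # ignore blank lines, e.g. at end of file
        label_str, img_path = line.split(' ', 1)  # split once, so paths may contain spaces
        data_info.append((int(label_str), img_path))
    return data_info


fake_file = io.StringIO('0 data/ants/1.jpg\n1 data/bees/my bee.jpg\n\n')
raw_lines = fake_file.readlines()
print(raw_lines[0].split(' '))      # ['0', 'data/ants/1.jpg\n']  <- newline stays in the path
print(parse_list_lines(raw_lines))  # [(0, 'data/ants/1.jpg'), (1, 'data/bees/my bee.jpg')]
```

Without strip(), Image.open() would receive a path ending in a newline and fail to find the file.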
DataLoader class
DataLoader combines a dataset and a sampler and provides an iterable over the given dataset: on every training iteration, one batch of batch_size samples is drawn from the DataLoader. Its signature and default arguments are shown below.
```python
torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None)
```
dataset: an instance of a subclass of the abstract Dataset class above, such as the training set or the validation set.
batch_size: the batch size, i.e. the number of samples per iteration; default 1.
shuffle: whether the samples are reshuffled at every epoch; default False. Feeding all training samples through the model once is called one epoch, and the number of iterations per epoch follows from the size of the training data and batch_size.
sampler: defines the strategy for drawing samples from the dataset, i.e. how indices are generated, either sequentially or randomly. The index received by __getitem__(self, index) in the Dataset example above is generated by this sampler class. See this blog post for a detailed explanation.
batch_sampler: returns the indices of one batch at a time, i.e. it groups the indices generated by the sampler into batches.
num_workers: the number of subprocesses used to load data; the default 0 means data is loaded in the main process.
collate_fn: merges the samples and labels of one batch into batch form.
drop_last: if True and the dataset size is not divisible by batch_size, the last incomplete batch is dropped; if False, that last batch is kept and is simply smaller.
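The epoch arithmetic from the shuffle and drop_last entries, as a small sketch. batch_sizes() is a hypothetical helper that lists the size of every batch in one epoch; its counting rule follows the BatchSampler.__iter__ loop shown later in this note, so the length of its result should equal the number of iterations per epoch for a map-style dataset with the default samplers. The numbers 245 and 16 are arbitrary.

```python
def batch_sizes(num_samples, batch_size, drop_last=False):
    """Hypothetical helper: size of each batch in one epoch, in order."""
    sizes = [batch_size] * (num_samples // batch_size)  # full batches
    remainder = num_samples % batch_size
    if remainder and not drop_last:
        sizes.append(remainder)                        # smaller last batch is kept
    return sizes


sizes = batch_sizes(245, 16)
print(len(sizes), sizes[-1])                      # 16 5
print(len(batch_sizes(245, 16, drop_last=True)))  # 15
print(len(batch_sizes(240, 16)))                  # 15 (divisible, so drop_last has no effect)
```

In other words, iterations per epoch are ceil(N / batch_size) with drop_last=False and floor(N / batch_size) with drop_last=True.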
- DataLoader use
A simple use of DataLoader is to add the following code after the Dataset instances are created. It reads the training set in random order with a batch size of 2.
```python
from torch.utils.data import DataLoader

train_loader = DataLoader(dataset=train_set, batch_size=2, shuffle=True)
for i, data in enumerate(train_loader):
    inputs, labels = data  # inputs: stacked image tensors, labels: tensor of their labels
```
Then set a breakpoint on the loop and step through it in the debugger to observe how DataLoader works.
First, the iterator is initialized. Execution enters the __iter__(self) method of the DataLoader(object) class, which chooses a single-process or multi-process iterator. With the default num_workers=0 it enters __init__(self, loader) of the single-process iterator class _SingleProcessDataLoaderIter(_BaseDataLoaderIter), which in turn calls __init__(self, loader) of the base class _BaseDataLoaderIter(object), and so on.
```python
class DataLoader(object):
    # ......
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __init__(self, loader):
        super(_SingleProcessDataLoaderIter, self).__init__(loader)
        assert self._timeout == 0
        assert self._num_workers == 0
        self._dataset_fetcher = _DatasetKind.create_fetcher(
            self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)


class _BaseDataLoaderIter(object):
    def __init__(self, loader):
        self._dataset = loader.dataset
        self._dataset_kind = loader._dataset_kind
        self._IterableDataset_len_called = loader._IterableDataset_len_called
        self._auto_collation = loader._auto_collation
        self._drop_last = loader.drop_last
        self._index_sampler = loader._index_sampler
        self._num_workers = loader.num_workers
        self._pin_memory = loader.pin_memory and torch.cuda.is_available()
        self._timeout = loader.timeout
        self._collate_fn = loader.collate_fn
        self._sampler_iter = iter(self._index_sampler)
        self._base_seed = torch.empty((), dtype=torch.int64).random_().item()
        self._num_yielded = 0
```
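DataLoader is split from these iterator classes because DataLoader is an iterable rather than an iterator: every for loop calls __iter__() and receives a fresh iterator object that holds the state of one pass (in the real code, the sampler iterator and counters created in _BaseDataLoaderIter.__init__). Below is a simplified pure-Python sketch of that pattern; ToyLoader and ToyLoaderIter are hypothetical classes with no sampling, workers or collation.

```python
class ToyLoader:
    """Hypothetical stand-in for DataLoader: an iterable, not an iterator."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __iter__(self):
        # Like DataLoader.__iter__, build a brand-new iterator object on every call
        return ToyLoaderIter(self)


class ToyLoaderIter:
    """Hypothetical stand-in for _SingleProcessDataLoaderIter: holds per-epoch state."""

    def __init__(self, loader):
        self.loader = loader
        self.pos = 0  # each new iterator starts from the beginning

    def __iter__(self):
        return self

    def __next__(self):
        if self.pos >= len(self.loader.data):
            raise StopIteration
        batch = self.loader.data[self.pos:self.pos + self.loader.batch_size]
        self.pos += self.loader.batch_size
        return batch


loader = ToyLoader(list(range(5)), batch_size=2)
for epoch in range(2):
    print(epoch, [batch for batch in loader])
# 0 [[0, 1], [2, 3], [4]]
# 1 [[0, 1], [2, 3], [4]]
```

Because each epoch starts from a new iterator, the same train_loader can be looped over once per epoch without being recreated.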
Once the iterator is initialized, data reading begins. Execution first enters the __next__(self) method of the _BaseDataLoaderIter(object) class, then jumps to the _next_data(self) method of _SingleProcessDataLoaderIter, and from there into the __iter__(self) method of the sampler class BatchSampler(Sampler). This is where the indices of all samples in one batch are generated and collected in a list, as shown in the figure. Back in the _next_data(self) method of _SingleProcessDataLoaderIter, the generated indices are passed to self._dataset_fetcher.fetch(index). Finally, execution jumps to __getitem__(self, index) of our custom DataSet(Dataset) class, which is called with one index at a time, batch_size times. After the loop finishes, default_collate(batch) is called by default to merge the data of this batch.
```python
class _BaseDataLoaderIter(object):
    def __next__(self):
        data = self._next_data()
        # ......


class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    # ......
    def _next_data(self):
        index = self._next_index()  # may raise StopIteration
        data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        if self._pin_memory:
            data = _utils.pin_memory.pin_memory(data)
        return data


class BatchSampler(Sampler):
    # ......
    def __iter__(self):
        batch = []
        for idx in self.sampler:
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if len(batch) > 0 and not self.drop_last:
            yield batch


def default_collate(batch):
    # ......
```
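The call chain traced above can be condensed into a pure-Python sketch of one epoch: sampler, then batch sampler, then fetch, then collate. Every name here (RangeDataset, sequential_sampler, random_sampler, batch_sampler, fetch, simple_collate) is hypothetical. Only the grouping loop is copied from BatchSampler.__iter__, and simple_collate is far simpler than the real default_collate, which stacks tensors into batched tensors.

```python
import random


class RangeDataset:
    """Hypothetical map-style dataset: sample i is (i * 10, i % 2), i.e. (fake image, label)."""

    def __len__(self):
        return 5

    def __getitem__(self, index):
        return index * 10, index % 2


def sequential_sampler(n):
    # shuffle=False: indices in order
    return iter(range(n))


def random_sampler(n, seed=0):
    # shuffle=True: a random permutation of all indices, each used once
    indices = list(range(n))
    random.Random(seed).shuffle(indices)
    return iter(indices)


def batch_sampler(sampler, batch_size, drop_last):
    # Same grouping loop as BatchSampler.__iter__ above
    batch = []
    for idx in sampler:
        batch.append(idx)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if len(batch) > 0 and not drop_last:
        yield batch


def fetch(dataset, indices):
    # One __getitem__ call per index (the real fetcher then also applies collate_fn)
    return [dataset[idx] for idx in indices]


def simple_collate(samples):
    # Much simplified collate: [(x0, y0), (x1, y1)] -> ([x0, x1], [y0, y1])
    xs, ys = zip(*samples)
    return list(xs), list(ys)


dataset = RangeDataset()
for indices in batch_sampler(sequential_sampler(len(dataset)), batch_size=2, drop_last=False):
    images, labels = simple_collate(fetch(dataset, indices))
    print(indices, images, labels)
# [0, 1] [0, 10] [0, 1]
# [2, 3] [20, 30] [0, 1]
# [4] [40] [0]
```

Swapping sequential_sampler for random_sampler mimics shuffle=True, and drop_last=True would drop the final [4] batch.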
Summary: the analysis above shows that what data is read, and how, is determined entirely by our custom DataSet class, while DataLoader mainly handles the iteration. On each iteration, BatchSampler is called to generate the indices of one batch; those indices are then passed one at a time to DataSet's __getitem__(self, index) to get the corresponding image and label. Once all images and labels of the batch have been fetched, they are collated and returned, which completes the reading of one batch.
reference
https://blog.csdn.net/qq_31622015/article/details/90573874
https://www.cnblogs.com/marsggbo/p/11308889.html
https://www.cnblogs.com/jiaxin359/p/7324077.html