How to read from a high-IO dataset in PyTorch that grows from epoch to epoch

David Parks :

I use TensorFlow, but I'm writing documentation for users who will work with a variety of deep learning frameworks.

When working with datasets that don't fit on the local filesystem (TB+), I sample data from a remote data store and write the samples locally in TensorFlow's standard TFRecord format.

During the first epoch of training I will have sampled only a few values, so an epoch of local data is very small, and I train on it. At epoch 2 I re-examine which data files my sampling subprocesses have produced (now more of them) and train on the expanded set of local data files for the next epoch. I repeat this every epoch. In this way I build up a local cache of samples and can evict older samples as the local storage fills up. The local sample cache grows at about the time the model needs the variance most (towards the latter part of training).
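The per-epoch rescan-and-evict step described above could look roughly like this in plain Python. This is a minimal sketch, not the author's code: the helper name `refresh_cache_index`, the `.bin` suffix for completed samples, and the oldest-first eviction policy (by file modification time) are all assumptions.

```python
from pathlib import Path


def refresh_cache_index(cache_dir, max_bytes):
    """Return the sample files to train on this epoch (hypothetical helper).

    Assumes completed samples end in ".bin"; anything else (e.g. files still
    being written under a temporary name) is ignored. If the cache exceeds
    max_bytes, the oldest files are deleted first.
    """
    entries = []
    for path in Path(cache_dir).glob("*.bin"):
        st = path.stat()
        # (mtime, name) gives oldest-first order; the name breaks ties.
        entries.append((st.st_mtime, path.name, st.st_size, path))
    entries.sort()

    total = sum(size for _, _, size, _ in entries)
    # Evict the oldest samples until the cache fits in the byte budget.
    while entries and total > max_bytes:
        _, _, size, path = entries.pop(0)
        path.unlink()
        total -= size

    return [str(path) for _, _, _, path in entries]
```

Each epoch, the returned list would be handed to whatever reader the framework uses, for example a list of files for TensorFlow or a `Dataset` for PyTorch.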

In Python/TensorFlow it's crucial that I not deserialize the data in the Python training-loop process: a single GIL-bound process can't sustain the data transfer rates involved (300-600 MB/sec of raw, incompressible scientific data), so GPU utilization suffers whenever the GIL prevents the training loop from being fed fast enough.

Writing the samples to TFRecord files from subprocesses (Python multiprocessing) lets TensorFlow's native TFRecordDataset do the deserialization outside of Python. This sidesteps the GIL issue, and I can saturate a GPU at high IO data rates.
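Because the training side rescans the cache directory while the sampling subprocesses are still writing, a scan could pick up a half-written file. One common guard, sketched here with a hypothetical helper name and not taken from the original setup, is to write each sample under a temporary name and atomically rename it once complete. The `.bin` final suffix is an assumption; the same pattern applies to any final file name, such as a TFRecord file.

```python
import os
import tempfile


def write_sample_atomically(cache_dir, name, payload):
    """Write one sample so readers never observe a partial file (sketch)."""
    # Write to a temporary file in the same directory first, so the final
    # rename stays within one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
            f.flush()
            os.fsync(f.fileno())  # make sure the bytes hit disk before publishing
        final_path = os.path.join(cache_dir, name + ".bin")
        # os.replace is atomic on POSIX when source and target share a filesystem.
        os.replace(tmp_path, final_path)
    except BaseException:
        # Remove the temporary file if anything failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return final_path
```

With this convention, a reader that only lists `*.bin` files never sees an in-progress `.tmp` file.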

I would like to know how to address this issue in PyTorch. I'm writing about the sampling strategy being used and want to provide specific recommendations to users of both TensorFlow and PyTorch, but I don't know the PyTorch preprocessing ecosystem well enough to write about it in sufficient detail.

Side note: the only purely Python-based solution that could support these data transfer rates may come in Python 3.8 with System V shared memory and multiprocessing, but I haven't tried it yet because support for it isn't quite sufficient (soon it will be). Existing multiprocessing solutions aren't sufficient because they require deserialization in the training-loop process and thus hold the GIL during deserialization at high IO rates.
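The Python 3.8 feature this note refers to is `multiprocessing.shared_memory`. A minimal sketch of its basic pattern, with a hypothetical function name: one handle creates a named block and writes raw bytes into it, and a second handle attaches to the same block by name, which is how a separate process would read the data without receiving a pickled copy. Both handles live in one process here only to keep the example self-contained.

```python
from multiprocessing import shared_memory


def demo_shared_block(payload: bytes) -> bytes:
    """Write raw bytes into a named shared-memory block and read them back
    through a second handle, as a consumer process would (sketch)."""
    # Producer side: allocate a block at least as large as the payload.
    producer = shared_memory.SharedMemory(create=True, size=len(payload))
    try:
        producer.buf[: len(payload)] = payload
        # Consumer side: attach by name. In a real pipeline, only the name
        # would be sent to another process, not the data itself.
        consumer = shared_memory.SharedMemory(name=producer.name)
        try:
            # The block may be rounded up to a page size, so slice to the
            # payload length; bytes() copies only so we can return a value.
            return bytes(consumer.buf[: len(payload)])
        finally:
            consumer.close()
    finally:
        producer.close()
        producer.unlink()  # free the block once no process needs it
```

In a real pipeline the reader would wrap `buf` directly (for example with `numpy.frombuffer`) instead of copying, and would keep the block open until those views are released.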

bombs :

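Actually, you can easily deserialize data in a subprocess by using torch.utils.data.DataLoader. By setting the num_workers argument to 1 or greater, you spawn worker subprocesses, each with its own Python interpreter and GIL. When the batches are tensors, their storage is placed in shared memory, so the main process only receives small handles rather than deserializing the data itself.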

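import torch

loader = torch.utils.data.DataLoader(your_dataset, num_workers=n, **kwargs)
for epoch in range(epochs):
    for batch_idx, data in enumerate(loader):
        # Samples are read and decoded in the worker processes; the main
        # process only receives ready-made batches here.
        train_step(data)  # hypothetical training step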
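The `your_dataset` placeholder above can be any map-style `torch.utils.data.Dataset`. Here is one possible sketch under stated assumptions: the class name `RawSampleDataset` is hypothetical, and each sample file is assumed to hold raw native-endian float32 values of a fixed length, standing in for whatever record format the sampler actually writes. The point is where the work happens: with `num_workers > 0`, `__getitem__` (the file read and decode) runs inside the worker processes.

```python
import torch
from torch.utils.data import Dataset


class RawSampleDataset(Dataset):
    """Map-style dataset over a fixed list of raw float32 sample files.

    Assumption: every file holds one sample as native-endian float32
    values, all of the same length.
    """

    def __init__(self, paths):
        self.paths = list(paths)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Runs in a DataLoader worker process when num_workers > 0.
        with open(self.paths[idx], "rb") as f:
            raw = bytearray(f.read())  # writable buffer for torch.frombuffer
        return torch.frombuffer(raw, dtype=torch.float32)
```

Passing, say, `num_workers=4` to the `DataLoader` moves the reads and `frombuffer` calls into four worker processes; the training loop then only receives collated batches.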

A DataLoader requires a torch.utils.data.Dataset to read from, and implementing a proper subclass may not be trivial in your case. If you need to recreate a Dataset instance for every epoch, you can do something like this:

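import torch

for epoch in range(epochs):
    # Rebuild the dataset so it sees every file sampled so far.
    dset = get_new_dataset()
    loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
    for batch_idx, data in enumerate(loader):
        train_step(data)  # hypothetical: do training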

Or, even better, prepare the next epoch's dataset before the current epoch finishes:

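import torch

dset = get_new_dataset()
loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)

for epoch in range(epochs):
    last_batch_idx = len(loader) - 1  # respects batch_size and drop_last
    for batch_idx, data in enumerate(loader):
        # On the last batch, build the next epoch's dataset in advance.
        # enumerate() keeps iterating the current loader, so rebinding
        # `loader` here only takes effect next epoch. Worker processes
        # are started only once the new loader is iterated.
        if batch_idx == last_batch_idx:
            dset = get_new_dataset()
            loader = torch.utils.data.DataLoader(dset, num_workers=n, **kwargs)
        train_step(data)  # hypothetical: do training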
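An alternative to recreating the DataLoader every epoch, sketched here rather than taken from the answer above, is an `IterableDataset` whose `__iter__` re-lists the cache directory. Because a `DataLoader` calls `iter()` on the dataset again for each new epoch, a single loader then picks up newly sampled files. The class name, the `*.bin` raw-float32 layout, and the file-per-worker split via `get_worker_info()` are assumptions.

```python
import glob
import os

import torch
from torch.utils.data import IterableDataset, get_worker_info


class GrowingCacheDataset(IterableDataset):
    """Iterable dataset that re-lists the cache directory every epoch.

    Hypothetical sketch: completed samples are assumed to be "*.bin" files
    of native-endian float32 values of equal length.
    """

    def __init__(self, cache_dir):
        self.cache_dir = cache_dir

    def __iter__(self):
        # Rescan on every iteration (i.e. every epoch) so new samples are
        # seen without rebuilding the DataLoader.
        paths = sorted(glob.glob(os.path.join(self.cache_dir, "*.bin")))
        info = get_worker_info()
        if info is not None:
            # Inside a worker: take every num_workers-th file so workers
            # do not yield duplicate samples.
            paths = paths[info.id :: info.num_workers]
        for path in paths:
            with open(path, "rb") as f:
                raw = bytearray(f.read())
            yield torch.frombuffer(raw, dtype=torch.float32)
```

Used as `DataLoader(GrowingCacheDataset(cache_dir), batch_size=..., num_workers=...)`, the same loader object can be iterated once per epoch. Two caveats: `DataLoader` rejects `shuffle=True` for an `IterableDataset`, so shuffling would have to happen inside `__iter__` (with a seed shared by all workers so the file split stays disjoint), and each worker lists the directory independently, so a file that appears or is evicted mid-scan may be seen by some workers and not others.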

As a side note, in most cases it's CPU-bound work that is limited by the GIL, not I/O-bound work: CPython releases the GIL while a thread is blocked on I/O, so threading is enough for purely I/O-heavy operations and you don't even need subprocesses. Deserialization, however, is CPU-bound, which is why it belongs in the DataLoader worker processes. For more information please refer to this question and this Wikipedia article.
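To illustrate the I/O-versus-CPU distinction, here is a minimal standard-library sketch (the function name and default pool size are arbitrary choices): threads overlap the blocking file reads, while any CPU-heavy decoding of the returned bytes would still contend for the GIL.

```python
from concurrent.futures import ThreadPoolExecutor


def read_files_concurrently(paths, max_workers=8):
    """Read whole files on a thread pool (sketch).

    CPython releases the GIL while a thread is blocked in a file read, so
    threads can overlap pure I/O. Decoding the returned bytes is CPU-bound
    and would belong in worker processes instead.
    """
    def read(path):
        with open(path, "rb") as f:
            return f.read()

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # map() returns results in the same order as the input paths.
        return list(pool.map(read, paths))
```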
