Reading variable-length data with the PyTorch DataLoader

  References:

  https://pytorch.org/docs/stable/data.html#dataloader-collate-fn

  https://blog.csdn.net/anshiquanshu/article/details/112868740

  When using the PyTorch deep learning framework, Dataset and DataLoader are unavoidable. The latter builds on the former and provides efficient data loading (worker processes, batch training, etc.).

  Taking RGB images as an example, a single sample from the dataset has shape (3, H, W), while a batch from the dataloader has shape (batch_size, 3, H, W). There is one extra dimension, the batch dimension, because the dataloader "stacks" the samples. The dataloader has a parameter called collate_fn whose default value is None; when you do not specify collate_fn, PyTorch calls its default collate function to stack the samples before handing them to you.
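  As a quick illustration, here is a minimal sketch (the dataset class and image size are made up) of that default stacking behaviour:

import torch
from torch.utils.data import Dataset, DataLoader

class FakeImageDataset(Dataset):
    """Returns fixed-size fake RGB images of shape (3, 32, 32)."""
    def __len__(self):
        return 8

    def __getitem__(self, index):
        return torch.randn(3, 32, 32)

# collate_fn is left unspecified, so the default collate stacks the samples
loader = DataLoader(FakeImageDataset(), batch_size=4)
batch = next(iter(loader))
print(batch.shape)  # torch.Size([4, 3, 32, 32])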

  However, when your samples have variable length, the default collate function cannot stack them. For example, I ran into the following error:

  RuntimeError: stack expects each tensor to be equal size, but got [2, 4] at entry 0 and [5, 4] at entry 1

  One sample has length 2 and the other has length 5, so they obviously cannot be stacked directly. When faced with variable-length data like this, you need to write a custom collate_fn that pads the samples. The PyTorch documentation puts it this way:
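  For reference, a minimal sketch (with a made-up dataset of random box tensors) that reproduces this error looks like this:

import torch
from torch.utils.data import Dataset, DataLoader

class VarLenBoxDataset(Dataset):
    """Each sample has a different number of (x1, y1, x2, y2) boxes."""
    def __init__(self):
        self.samples = [torch.randn(2, 4), torch.randn(5, 4)]  # shapes [2, 4] and [5, 4]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]

loader = DataLoader(VarLenBoxDataset(), batch_size=2)
next(iter(loader))  # RuntimeError: stack expects each tensor to be equal size ...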

  A custom collate_fn can be used to customize collation, e.g., padding sequential data to max length of a batch.

  So how do you write a custom collate_fn? What are its input and output? Let's look at this example:

from torch.nn.utils.rnn import pad_sequence

def padding_collate_fn(data_batch):
    # data_batch is a list of samples, each a dict returned by dataset.__getitem__
    batch_bbox_list = [item['bbox'] for item in data_batch]
    batch_label_list = [item['label'] for item in data_batch]
    batch_filename_list = [item['filename'] for item in data_batch]

    # Pad every sample up to the longest sequence in this batch
    padding_bbox = pad_sequence(batch_bbox_list, batch_first=True, padding_value=0)
    padding_label = pad_sequence(batch_label_list, batch_first=True, padding_value=5)

    result = dict()
    result["bbox"] = padding_bbox
    result["label"] = padding_label
    result["filename"] = batch_filename_list

    return result

  First of all, my dataset returns a dictionary per sample. The code above pulls the values out of each dictionary, pads and stacks them, and finally returns everything in one big dictionary. The pad_sequence function lives in torch.nn.utils.rnn and is very handy here.
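  Using it is then just a matter of passing the function to the dataloader. A sketch, assuming a dataset instance my_detection_dataset whose samples are such dictionaries:

from torch.utils.data import DataLoader

# my_detection_dataset is a placeholder for a dataset whose __getitem__ returns
# a dict with 'bbox', 'label' and 'filename' keys, as described above
loader = DataLoader(my_detection_dataset,
                    batch_size=4,
                    shuffle=True,
                    collate_fn=padding_collate_fn)

batch = next(iter(loader))
# batch["bbox"]:     (4, max_boxes_in_batch, 4), zero-padded
# batch["label"]:    padded with 5 up to max_boxes_in_batch
# batch["filename"]: a plain Python list of 4 filenames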

  In fact, the batch passed to collate_fn is simply a Python list of batch_size samples, i.e., the results of dataset[i] for the sampled indices. Once you know this, the problem is solved.
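  A rough illustration of what the dataloader hands to collate_fn (dataset and indices here are placeholders):

# indices 0..3 stand in for whatever the sampler actually draws
data_batch = [dataset[i] for i in range(4)]   # what DataLoader hands to collate_fn
result = padding_collate_fn(data_batch)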
