DataLoader PyTorch 主要参数的含义

定义:

DataLoader类是一个用于从数据集(dataset)中加载数据,并以迭代器(iterator)的形式返回数据样本(data samples)的工具¹²。您给出的两个字典(dictionary)分别是训练集(train set)和测试集(test set)的数据加载参数,下面我会逐一解释它们的含义和默认值:

 

举例演示:

假设prefetch_factor=2,batch_size=4,persistent_workers=True,你的回答里面的每个worker每次获取多少数据呢?

假设您有一个名为train_dataset的数据集,它包含了100个数据样本,每个样本由一个特征向量和一个标签组成。您可以使用以下代码创建一个使用prefetch_factor=2,batch_size=4,persistent_workers=True的DataLoader对象:

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=4, num_workers=8, prefetch_factor=2, persistent_workers=True)

 小提示:

pin_memory将数据样本保存在锁页内存,这里的内存指的是cpu的内存还是gpu的内存?

pin_memory参数是一个布尔值,表示是否将数据样本保存在锁页内存(pinned memory)中。这里的内存指的是CPU的内存,而不是GPU的内存12

 降低内存占用,推到的时候用到:

testloader_params = {
          'batch_size': 1,
          'shuffle': False,
          'num_workers': 0,
          'pin_memory': True,
          # 'prefetch_factor': 4,
          # 'persistent_workers': True,
          'drop_last': False
}

猜你喜欢

转载自blog.csdn.net/u010087338/article/details/132434265