2 torch.utils.data.DataLoader()

torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None, generator=None, *, prefetch_factor=2, persistent_workers=False)

2.1 作用

对数据进行预处理,输出可以用来进行训练的数据张量。

2.2 参数

dataset:传入torch.utils.data.Dataset类的一个实例。

batch_size:mini batch的大小,应为int型。

shuffle:True的话是指数据是否会被随机打乱,默认False。

sampler:自定义的采样器(shuffle=True时会构建默认的采样器,如果想使用自定义的方法需要构造一个torch.utils.data.Sampler的实例来进行采样,并设置shuffle=False,将实例作为参数传入),返回一个数据数据的下标索引。

batch_sampler:和sampler类似,不过batch_sampler返回的是一个mini batch的数据索引,而sampler返回的是下标索引。

num_workers:dataloader使用的进程数目,应为int型。

扫描二维码关注公众号,回复: 15235136 查看本文章

collect_fn:传入一个自定义的函数,定义如何把一批Dataset的实例转换为包含迷你批次的数据张量,例如这里是YoloV3里的collect_fn:

def yolo_dataset_collate(batch):
    images = []
    bboxes = []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)
    images = np.array(images)
    return images, bboxes

pin_memory:True的话会把数据转移到和GPU内存相关联的CPU内存中,从而能够加快GPU载入数据的速度。

drop_last:设置为True的话,当batch_size不能整除dataset里的数据总数时,会将最后一个batch抛弃,也就是说每一个batch都严格等于batch_size。

timeout:值如果大于零,就会决定在多进程情况下对载入数据的等待时间。

worker_init_fn:决定了每个子进程开始时运行的函数,这个函数运行在随机种子设置以后、载入数据之前。

multiprocessing_context:官方文档暂时未给出。

generator:如果不是none,这个随机数发生器将用来生成随机索引和多进程。(官方文档翻译过来的)

prefetch_factor:每个进程开始之前预加载的sample数。

persistent_workers:如果设置为True,dataloader不会在数据集被使用一次后关闭工作进程。

2.3 使用方法

在创建好了一个DataLoader的实例之后,需要利用for循环来读取批量数据,在循环中进行每个batch的训练:

gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers = num_workers, pin_memory=True, drop_last=True, collate_fn=yolo_dataset_collate)
# 第一个参数train_dataset 为Dataset类的一个实例

for iteration, batch in enumerate(gen):
    # for循环里为每个batch的训练内容

Dataloader每次循环时,先使用Dataset里的__getitem__方法获取batchsize个数据(也就是上面代码for循环里的batch),再使用collect_fn函数对batch做一些自定义的操作。

参考《深入浅出PyTorch》张校捷

猜你喜欢

转载自blog.csdn.net/fuss1207/article/details/123044790
今日推荐