Pytorch:Dataset总结

教科书 Pytorch入门与实践第五章pytorch-book/chapter5.ipynb at master · chenyuntc/pytorch-book · GitHub

1、TORCH.UTILS.DATA  官网地址torch.utils.data — PyTorch 1.11.0 documentation

在PyTorch中,数据加载可通过自定义的数据集对象。数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Dataset,并实现两个Python魔法方法:

  • __getitem__:返回一条数据,或一个样本。obj[index]等价于obj.__getitem__(index)
  • __len__:返回样本的数量。len(obj)等价于obj.__len__()

其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder. 
创建Dataset例子:

 

2. class torch.utils.data.sampler.Sampler(data_source)
参数: data_source (Dataset) – dataset to sample from

作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度。

3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
参数:
* dataset (Dataset): 加载数据的数据集
* batch_size (int, optional): 每批加载多少个样本
* shuffle (bool, optional): 设置为“真”时,在每个epoch对数据打乱.(默认:False)
* sampler (Sampler, optional): 定义从数据集中提取样本的策略,返回一个样本
* batch_sampler (Sampler, optional): like sampler, but returns a batch of indices at a time 返回一批样本. 与atch_size, shuffle, sampler和 drop_last互斥.
* num_workers (int, optional): 用于加载数据的子进程数。0表示数据将在主进程中加载​​。(默认:0)
* collate_fn (callable, optional): 合并样本列表以形成一个 mini-batch.  # callable可调用对象
* pin_memory (bool, optional): 如果为 True, 数据加载器会将张量复制到 CUDA 固定内存中,然后再返回它们.pin memory中的数据转到GPU会快一些
* drop_last (bool, optional): 设定为 True 如果数据集大小不能被批量大小整除的时候, 将丢掉最后一个不完整的batch,(默认:False).
* timeout (numeric, optional): 如果为正值,则为从工作人员收集批次的超时值。应始终是非负的。(默认:0)
* worker_init_fn (callable, optional): If not None, this will be called on each worker subprocess with the worker id (an int in ``[0, num_workers - 1]``) as input, after seeding and before data loading. (default: None).
* generator (torch.Generatoroptional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)如果此选项不是None,RandomSampler 将使用RNG去生成随机的索引,在多进程中也会生成每个进程基准种子)
* prefetch_factor (intoptionalkeyword-only arg) – Number of samples loaded in advance by each worker. 2 means there will be a total of 2 * num_workers samples prefetched across all workers. (default: 2)对每个进程来说,需要提前装载的样本数量。当此值为2时,对所有的进程来数,需要提前获取的样本数量为2*num_workers,默认值为2)
persistent_workers (booloptional) – If True, the data loader will not shutdown the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)如果设置为True,对每个进程来数,当一个数据集被使用一次后,此进程并不会被关闭,这样就会保持进程中的数据集实例是活的。默认值是False。
代码例子:

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=3, shuffle=True,
num_workers=0, drop_last=False)

dataiter = iter(dataloader)
imgs, labels = next(dataiter)
imgs.size() # batch_size, channel, height, weight

输出:torch.Size([3, 3, 224, 224])

4、Dataset Types:DataLoader中最重要的参数就是dataset,它决定了要装载的数据集。torch支持两种类型的数据集。
(1)map-style 类型。一个map-style类型是实现了__getitem__() 和__len__()协议的类,它代表了一个从索引/键值 到数据样本的映射。例如,对于一个通过dataset[idx]访问的数据集,可以读到第idx个图片,并从磁盘的文件中取到对应的标签。看下面例子吧

import os
from PIL import  Image
import numpy as np
from torchvision import transforms as T

transform = T.Compose([
    T.Resize(224), # 缩放图片(Image),保持长宽比不变,最短边为224像素
    T.CenterCrop(224), # 从图片中间切出224*224的图片
    T.ToTensor(), # 将图片(Image)转成Tensor,归一化至[0, 1]
    T.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]) # 标准化至[-1, 1],规定均值和标准差
])

class DogCat(data.Dataset):
    def __init__(self, root, transforms=None):
        imgs = os.listdir(root)
        self.imgs = [os.path.join(root, img) for img in imgs]
        self.transforms=transforms
        
    def __getitem__(self, index):
        img_path = self.imgs[index]
        label = 0 if 'dog' in img_path.split('/')[-1] else 1
        data = Image.open(img_path)
        if self.transforms:
            data = self.transforms(data)
        return data, label
    
    def __len__(self):
        return len(self.imgs)

dataset = DogCat('./data/dogcat/', transforms=transform)
img, label = dataset[0]
for img, label in dataset:
    print(img.size(), label)

输出:

torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 0
torch.Size([3, 224, 224]) 1


(2)Iterable-style 类型。这种类型的父类是IterableDataset,并且实现了 __iter__() 协议,代表了在数据样本的迭代器。这种类型非常适合随机读取非常难或者不可能的情况,这种情况下批的大小取决于得到的数据。例如有这样一个数据集,可以调用iter(dataset),可以从数据库、远程服务器或者实时产生的日志获取样本流。

注意:当使用多进程加载Iterable-style 类型的DataLoader时,一份样本会被复制至所有的进程中,这份复制的样本将被差异配置以避免重复数据。请参阅IIterableDataset文档以如何实现它。

dataloader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它,例如:

for batch_datas, batch_labels in dataloader:
    train()

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)

dataiter = iter(dataloader)
batch_datas, batch_labesl = next(dataiter)

5、PyTorch中还单独提供了一个sampler模块,用来对数据进行采样。常用的有随机采样器:RandomSampler,当dataloader的shuffle参数为True时,系统会自动调用这个采样器,实现打乱数据。默认的是采用SequentialSampler,它会按顺序一个一个进行采样。这里介绍另外一个很有用的采样方法: WeightedRandomSampler,它会根据每个样本的权重选取数据,在样本比例不均衡的问题中,可用它来进行重采样。

语法:CLASStorch.utils.data.Sampler(data_source)
功能:所有sampler的基类。所有子类都必须重写__iter__()方法,提供一种在dataset elements的indices/keys上的迭代方法。选择性重写__len__()方法,返回这个迭代器的长度。
参数:

  • data_source (Dataset):dataset to sample from.

语法:CLASStorch.utils.data.SequentialSampler(data_source)
功能:顺序采样。

语法:CLASS torch.utils.data.RandomSampler
(data_sourcereplacement=Falsenum_samples=Nonegenerator=None)
功能:随机采样。

语法:CLASS torch.utils.data.WeightedRandomSampler
(weightsnum_samplesreplacement=Truegenerator=None)
功能:权重采样。构建WeightedRandomSampler时需提供两个参数:每个样本的权重weights、共选取的样本总数num_samples,以及一个可选参数replacement。权重越大的样本被选中的概率越大,待选取的样本数目一般小于全部的样本数目。replacement用于指定是否可以重复选取某一个样本,默认为True,即允许在一个epoch中重复采样某一个数据。如果设为False,则当某一类的样本被全部选取完,但其样本数目仍未达到num_samples时,sampler将不会再从该类中选择数据,此时可能导致weights参数失效。
代码:

dataset = DogCat('data/dogcat/', transforms=transform)

# 狗的图片被取出的概率是猫的概率的两倍
# 两类图片被取出的概率与weights的绝对大小无关,只和比值有关
weights = [2 if label == 1 else 1 for data, label in dataset]
weights

from torch.utils.data.sampler import  WeightedRandomSampler
sampler = WeightedRandomSampler(weights,\
                                num_samples=9,\
                                replacement=True)
dataloader = DataLoader(dataset,
                        batch_size=3,
                        sampler=sampler)
for datas, labels in dataloader:
    print(labels.tolist())

输出:
[1, 2, 2, 1, 2, 1, 1, 2]

[1, 0, 1]
[0, 0, 1]
[1, 0, 1]

猜你喜欢

转载自blog.csdn.net/qimo601/article/details/123658884