Pytorch学习(一) --- 数据加载之Dataset类和DataLoader类

在进行深度学习任务时,一个完整的baseline通常分为以下几个部分:

  1. 定义模型。这里需要构建网络模型,后面用这个模型去训练。
  2. 定义数据增强。这里主要是在数据量少的情况下,对数据进行一些增强,比如平移,翻转,裁剪等操作,以提高模型的泛化能力(这一步不是必须的)。
  3. 定义数据加载。这里定义数据加载器,使得模型训练时模型能源源不断地获取数据进行训练。对于Pytorch而言,数据记载主要需要用到DatasetDataLoader这两个类。
  4. 模型训练。这里首先需要定义模型的一些参数配置,优化器,损失函数定义之类的,至此我们就可以进行训练了。
  5. 模型测试

本文主要是对Pytorch中定义数据加载的方法做一个学习。

Dataset

Dataset是Pytorch中的一个数据读取类,它已经包含了很多常见的数据集,如下:


torchvision.datasets中包含了以下数据集

  • MNIST
  • COCO(用于图像标注和目标检测)(Captioning and Detection)
  • LSUN Classification
  • ImageFolder
  • Imagenet-12
  • CIFAR10 and CIFAR100
  • STL10

我们可以直接使用这个Dataset类里面的数据集,示例如下:

dset.MNIST(root, train=True, transform=None, target_transform=None, download=False)

其中

  • root : processed/training.pt 和 processed/test.pt 的主目录
  • train : True = 训练集, False = 测试集
  • download : True = 从互联网上下载数据集,并把数据集放在root目录下. 如果数据集之前下载过,将处理过的数据(minist.py中有相关函数)放在processed文件夹下。
  • transform表示数据是否需要做预处理,默认为None/不要。
    上述就完成了MNIST数据集的加载定义。

Dataset的定义如下:

class Dataset(object):
	def __init__(self):
		...
		
	def __getitem__(self, index):
		return ...
	
	def __len__(self):
		return ...

主要包含两个方法:

__getitem__()

__getitem__函数的作用是根据索引index遍历数据,一般返回image的Tensor形式和对应标注。当然也可以多返回一些其它信息,这个根据需求而定。

 __len__()

__len__函数的作用是返回数据集的长度。

在我们训练自己的数据时,需要继承它,并需要重写__getitem__()__len__()这两个方法。
示例如下:

class CarDataset(Dataset):
    def __init__(self, img_df, transform=None):
        self.img_df = img_df
        if transform is not None:
            self.transform = transform
        else:
            self.transform = None
    
    def __getitem__(self, index):
        # start_time = time.time()
        # img = Image.open(self.img_df.iloc[index]['index']).convert('RGB')
        img = cv2.imread(self.img_df.iloc[index]['filename'])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        if self.transform is not None:
            img = self.transform(image=img)
        return img['image'], torch.from_numpy(np.array(self.img_df.iloc[index]['label']))
    
    def __len__(self):
        return len(self.img_df)
  • getitem 函数接收一个index,然后返回图片数据和标签,这个index通常指的是一个list的index,这个list的每个元素就包含了图片数据的路径和标签信息
  • getitem 函数返回的图像必须是tensor,所以我们还需要对读到的img做一个Totensor的转换,这个操作通常会写在transform函数里。如下:
train_transform = Compose([     
                                    Resize(288,352),
                                    HorizontalFlip(),
                                    OneOf([
                                        RandomContrast(),
                                        RandomGamma(),
                                        RandomBrightness(),
                                    ], p=0.3),
                                    OneOf([
                                        CLAHE(p=0.5),
                                        GaussianBlur(3, p=0.3),
                                        IAASharpen(alpha=(0.2,0.3), p=0.3),
                                    ], p=1),
                                    Normalize(
                                        mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225],
                                    ),
                                    RandomCrop(256, 320),
                                    ToTensor()
                                ])

最后一个是ToTensor()。

DataLoader

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)

构建一个可迭代的数据装载器,可以理解为在训练过程中,DataLoader将自定义的Dataset根据batch size大小、是否shuffle等封装成一个又一个batch大小的Tensor,数据给模型进行训练测试。

即在DataLoder中,会触发Mydataset中的getiterm函数读取一张图片的数据和标签,并拼接成一个batch返回,作为模型真正的输入。

参数表如下:

  • dataset (Dataset) – 加载数据的数据集。
  • batch_size (int, optional) – 每个batch加载多少个样本(默认: 1)。
  • shuffle (bool, optional) – 设置为True时会在每个epoch重新打乱数据(默认: False).
  • sampler (Sampler, optional) – 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。
  • num_workers (int, optional) – 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)
  • collate_fn (callable, optional) – 将一个list的sample组成一个mini-batch的函数
  • pin_memory (bool, optional) – 如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存(CUDA pinned memory)中
  • drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

常用的是

  • dataset
  • batch_size
  • shuffle
  • num_workers

自定义示例:

train_loader = torch.utils.data.DataLoader(
            CarDataset(train_label,
                        train_transform,
            ), batch_size=batch_size, shuffle=True, num_workers=work_num, pin_memory=True
        )
参考

https://blog.csdn.net/u014380165/article/details/79058479?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task
https://pytorch-cn.readthedocs.io/zh/latest/torchvision/torchvision-datasets/
https://blog.csdn.net/g11d111/article/details/81504637?depth_1-utm_source=distribute.pc_relevant.none-task&utm_source=distribute.pc_relevant.none-task

原创文章 96 获赞 24 访问量 3万+

猜你喜欢

转载自blog.csdn.net/c2250645962/article/details/105198255