pytorch中使用Dataset和DataLoader创建自定义数据集 入门

介绍

pytorch中,我们可以使用torch.utils.data.DataLoadertorch.utils.data.Dataset加载数据集,具体来说,可以简单理解为Dataset是数据集,他提供数据与索引之间的映射,同时也要有标签。而DataLoader是将Dataset中的数据迭代提取出来,从而能够提供给模型。
所以,具体流程是,我们应该先按照要求先建立一个Dataset,之后再建立一个DataLoader,然后就可以用了。
pytorch中有很多现成的数据集,我们下载就可以使用。但是更多时候我们要建立自己的数据集,我也是入门,所以先建立一个带标签的图像数据集。

参考

建立Dataset

我们可以继承torch.utils.data.Dataset类,必须要重写__init__, __len__, 和 __getitem__这三个函数。其中 __len__能够返回我们数据集中的数据个数,__getitem__能够根据索引返回数据。

前提

我们有一个文件夹,里面有很多猫、狗和汽车的照片,此外有一个csv文件,里面是每张照片对应的类别,也就是标签。我们根据这个照片文件夹和csv文件,来建立我们的带标签数据集。

  • 对于图片文件夹:0——29张图片为猫,30——59张图片为狗,其他为汽车。
    图片文件夹
  • 对于标签csv文件,每一行中首先是图片名,然后是类别。其中0代表猫,1代表狗,2代表汽车。如下图:
    标签文件

具体代码

import os
from torchvision.io import read_image
import pandas as pd
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np

class myImageDataset(Dataset):
    def __init__(self, img_dir, img_label_dir, transform=None):
        super().__init__()
        self.img_dir = img_dir
        self.img_labels = pd.read_csv(img_label_dir)  # 这是一个dataframe,0是文件名,1是类别
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)  # 数据集长度
    
    def __getitem__(self, index):
        # 拼接得到图片文件路径
        # 例如img_dir为'D:/curriculum/2022learning/learnning_dataset/data/'
        # img_labels.iloc[index, 0]为5.jpg
        # 那么img_path为'D:/curriculum/2022learning/learnning_dataset/data/5.jpg'
        img_path = os.path.join(self.img_dir + self.img_labels.iloc[index, 0])
        image = read_image(img_path)  # tensor类型
        label = self.img_labels.iloc[index, 1]
        if self.transform is not None:
            image = self.transform(image)  # 对图片进行某些变换
        
        return image, label

代码中都有注释。

__init__()

类的初始化函数,其中img_dir为图片文件夹的根目录,img_label_dir为标签文件路径,transform为对数据项进行的变换。

__len__()

返回数据集长度。

__getitem__()

根据index,返回其在数据集中对应的数据和标签。

验证

通过如下代码,我们具体输出一张图片:

# 把图片对应的tensor调整维度,并显示
def tensorToimg(img_tensor):
    img = img_tensor.numpy()
    img = np.transpose(img_tensor, [1, 2, 0])
    plt.imshow(img)


label_dic = {
    
    0: 'cat', 1: 'dog', 2: 'car'}

label_path = 'D:/curriculum/2022learning/learnning_dataset/labels.csv'
img_root_path = 'D:/curriculum/2022learning/learnning_dataset/data/'
dataset = myImageDataset(img_root_path, label_path)

image, label = dataset.__getitem__(33)
print(image.shape)
print(label_dic[label])
tensorToimg(image)

结果
可以看到,数据集中,图片变为tensor,维度为[通道数,长,宽]。

DataLoader

之后就可以使用DataLoader对刚刚创建的数据集不断取出样本了。不再赘述。

dataloader = DataLoader(dataset, batch_size=5, shuffle=True)

这样,我们就建立了一个dataLoader。接下来我们输出一下看看:

for imgs, labels in dataloader:
    print(imgs.shape)
    print(labels)
    break

但是这里报错:stack expects each tensor to be equal size, but got [3, 268, 320] at entry 0 and [3, 480, 370] at ...,查询得知是数据集中图片大小不一,而这时Dataset中定义的参数transfom就派上了用场。我们让每张图片的大小都是224*224

from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Resize((224, 224))

dataset = myImageDataset(img_root_path, label_path, transform)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True)
for imgs, labels in dataloader:
    print(imgs.shape)
    print(labels)
    break

结果为:

torch.Size([5, 3, 224, 224])
tensor([0, 2, 2, 2, 1])

由于batch_size是5,而每个图片的形状为[3, 224, 224],因此一个batch的数据形状为:[5, 3, 224, 224]

其他使用DataLoader的方法

for index, (imgs, labels) in enumerate(dataloader):
    print(index)
    print(imgs.shape)
    print(labels)
    break

结果为:

0
torch.Size([5, 3, 224, 224])
tensor([1, 2, 0, 1, 1])
imgs, label = next(iter(dataloader))
print(imgs.shape)
print(labels)

结果为:

torch.Size([5, 3, 224, 224])
tensor([1, 2, 0, 1, 1])

得到了一批的图片和对应的标签,我们就能将其输入到模型中,并使用标签和预测结果计算损失。

猜你喜欢

转载自blog.csdn.net/qq_43219379/article/details/123381194
今日推荐