【深度学习】自定义数据集对象mydataset |继承torch.utils.data.Dataset类

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档


前言

与datasets.ImageFolder类似,深度学习课题中还有一种很常用的自定义数据集的方法:继承torch.utils.data.Dataset类
可以参考我之前的博客:【深度学习】datasets.ImageFolder 使用方法

datasets.ImageFolder返回的对象和继承torch.utils.data.Dataset的自定义数据集(如MyDataset)生成的对象类型是一样的的吗?:
是的。它们都是torch.utils.data.Dataset类的实例,都实现了__len__和__getitem__方法,可以被传递给torch.utils.data.DataLoader用于数据的迭代和批处理等操作。虽然它们的实现方式不同,但是它们都符合了torch.utils.data.Dataset的接口规范,因此可以被视为同一类型的对象。

一、自定义mydataset的例子

比如我要从指定文件夹里读取图片生成数据集:

import os
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_filenames = os.listdir(root_dir)
    
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, index):
        # 读取图像
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        
        # 对图像进行变换(如果有)
        if self.transform is not None:
            image = self.transform(image)
        
        return image

ImageDataset就是继承的dataset
最重要的就是这三部分:构造函数,两个魔术方法(len,getitem)

如果我对图像还有预处理的话,代码举例如下:

import os
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_filenames = os.listdir(root_dir)
    
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, index):
        # 读取图像
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        
        # 对图像进行变换(如果有)
        if self.transform is not None:
            image = self.transform(image)
        
        return image

# 定义变换函数
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # 调整图像大小
    transforms.ToTensor(),         # 将图像转换为张量
    transforms.Normalize(          # 归一化图像
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# 创建数据集实例
dataset = ImageDataset(root_dir='path/to/images', transform=transform)

# 获取第一张图像数据
image = dataset[0]

由于 transforms.ToTensor() 能够将 PIL.Image.Image 对象直接转换为张量,因此在这里可以直接使用 transforms.ToTensor() 进行转换,就不用再把PIL.Image.Image单独把转为ndarray了。具体来说,在 transforms.ToTensor() 中,会先将 PIL.Image.Image 对象转换为 numpy.ndarray 对象,然后再将其转换为张量。

二、torch.utils.data.Dataset长啥样

我们打开dataset类函数进去看看:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    functions: Dict[str, Callable] = {
    
    }

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, Dataset):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function

chatgpt的解读为:Dataset 类的定义,它是一个抽象类,所有表示从键到数据样本的数据集都应该继承它。所有的子类都应该覆盖 getitem 方法,支持根据给定的键获取数据样本。子类还可以选择性地覆盖 len 方法,它被许多 torch.utils.data.Sampler 实现和 torch.utils.data.DataLoader 的默认选项所使用,用于返回数据集的大小。如果数据集的键不是整数类型,需要提供一个自定义的采样器(sampler)来使其与 torch.utils.data.DataLoader 兼容。此外,Dataset 类还提供了一些方法和属性,如 add 方法、getattr 方法等。

raise NotImplementedError 表示该方法还没有被实现,需要在子类中进行实现。在 Python 中,使用 NotImplementedError 异常可以方便地提示开发者该方法还未被实现,这也是一种规范的实现方式。当然,你也可以直接在子类中实现这两个方法,而不是使用 NotImplementedError。

三.一些使用过的继承dataset类总结

3.1.在图像去噪任务中,使用patch将单张图片分割为多个子图训练

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class PatchDataset(Dataset):
    def __init__(self, noisy_image_folder, clean_image_folder, patch_size=64):
        self.noisy_image_folder = noisy_image_folder
        self.clean_image_folder = clean_image_folder
        self.patch_size = patch_size
        self.transform = transforms.Compose([
            transforms.Resize(patch_size + 16),
            transforms.RandomCrop(patch_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            )
        ])
        self.noisy_image_paths = [os.path.join(noisy_image_folder, x) for x in os.listdir(noisy_image_folder)]
        self.clean_image_paths = [os.path.join(clean_image_folder, x) for x in os.listdir(clean_image_folder)]
    
    def __len__(self):
        return len(self.noisy_image_paths)
    
    def __getitem__(self, idx):
        # 读取图像
        noisy_image = Image.open(self.noisy_image_paths[idx]).convert('L')
        clean_image = Image.open(self.clean_image_paths[idx]).convert('L')
        # 对图像进行 patch 操作
        patches = []
        for i in range(4):  # 每个图像分割成 4 个 patch
            for j in range(4):
                x = j * self.patch_size
                y = i * self.patch_size
                noisy_patch = noisy_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                clean_patch = clean_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                noisy_patch = self.transform(noisy_patch)
                clean_patch = self.transform(clean_patch)
                patches.append((noisy_patch, clean_patch))
        return patches

在这个示例中,我们定义了一个名为 PatchDataset 的自定义数据集类,它继承自 PyTorch 的 Dataset 类。在 init 函数中,我们传入了有噪点图像文件夹路径 noisy_image_folder、无噪点图像文件夹路径 clean_image_folder 和 patch 的大小 patch_size,并定义了变换函数 self.transform。在 getitem 函数中,我们读取有噪点图像和无噪点图像并对它们进行 patch 操作,将得到的 16 个 patch 组成一个列表并返回,其中每个元素是一个包含有噪点 patch 与对应的无噪点 patch 的元组。
需要注意的是,在这个示例中,我们将有噪点 patch 和无噪点 patch 都进行了归一化。这是因为在图像去噪任务中,我们需要将有噪点图像输入模型进行训练,同时需要使用无噪点图像作为标签进行监督学习。因此,对有噪点图像和无噪点图像进行相同的归一化操作可以简化代码并提高训练效果。

3.2.在HDR图像重建任务中,dataset类中的transform应该是神马样的(.hdr.exr文件和jpg打开方式不太一样)

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import cv2
import os

class PatchDataset(Dataset):
    def __init__(self, ldr_image_folder, hdr_image_folder, patch_size=64):
        self.ldr_image_folder = ldr_image_folder
        self.hdr_image_folder = hdr_image_folder
        self.patch_size = patch_size
        self.transform = transforms.Compose([
            transforms.Resize(patch_size + 16),
            transforms.RandomCrop(patch_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ])
        self.ldr_image_paths = [os.path.join(ldr_image_folder, x) for x in os.listdir(ldr_image_folder)]
        self.hdr_image_paths = [os.path.join(hdr_image_folder, x) for x in os.listdir(hdr_image_folder)]
    
    def __len__(self):
        return len(self.ldr_image_paths)
    
    def read_hdr_image(self, path):
        # 读取 HDR 图像
        hdr_image = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
        # 将像素值恢复到原始范围
        hdr_image = hdr_image / 65535.0 * 100.0  # 假设原始范围为 [0, 100]
        return hdr_image
    
    def __getitem__(self, idx):
        # 读取图像
        ldr_image = cv2.imread(self.ldr_image_paths[idx], cv2.IMREAD_COLOR)
        hdr_image = self.read_hdr_image(self.hdr_image_paths[idx])
        # 对图像进行 patch 操作
        patches = []
        for i in range(4):  # 每个图像分割成 4 个 patch
            for j in range(4):
                x = j * self.patch_size
                y = i * self.patch_size
                ldr_patch = ldr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                hdr_patch = hdr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                ldr_patch = self.transform(ldr_patch)
                hdr_patch = self.transform(hdr_patch)
                patches.append((ldr_patch, hdr_patch))
        return patches

我们首先定义了一个 read_hdr_image() 函数,它使用 OpenCV 的 imread() 函数来读取 HDR 图像,并将像素值恢复到原始范围。在 getitem 函数中,我们读取 LDR 图像和 HDR 图像并对它们进行 patch 操作,将得到的 16 个 patch 组成一个列表并返回,其中每个元素是一个包含 LDR patch 与对应的 HDR patch 的元组。

因为是chatgpt生成的代码 ,是否对错需要上电脑验证:
.hdr 或 .exr 格式的 HDR 图像的像素值通常是浮点数(比如在 OpenEXR 中,像素值的数据类型为 FLOAT),而且通常不是在 0 到 255 的范围内,而是在一个更大的范围内。具体的范围取决于采集设备和图像处理过程中所使用的参数。
在使用 OpenCV 来读取 HDR 图像时,它会将像素值缩放到 0 到 255 的范围内,因此我们需要手动将像素值恢复到原始范围。在这个示例中,我们假设原始范围为 [0, 100],因此我们将原始的浮点数像素值除以 65535(即 2**16-1,因为像素值在 OpenCV 中被存储为 16 位整数)并乘以 100,以将像素值恢复到原始范围。

猜你喜欢

转载自blog.csdn.net/weixin_46274756/article/details/130168398