[Deep Learning] Custom dataset object MyDataset | inheriting the torch.utils.data.Dataset class

Foreword

Similar to datasets.ImageFolder, there is another very common way to build a custom dataset in deep learning: inheriting the torch.utils.data.Dataset class.
You can refer to my previous blog: [Deep Learning] How to use datasets.ImageFolder

Is the object returned by datasets.ImageFolder the same type as an object generated by a custom dataset (such as MyDataset) that inherits torch.utils.data.Dataset?
Yes. Both are instances of the torch.utils.data.Dataset class, and both implement the __len__ and __getitem__ methods, so both can be passed to torch.utils.data.DataLoader for iteration and batching. Although their implementations differ, they both conform to the torch.utils.data.Dataset interface, so they can be regarded as objects of the same type.
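
For example, a minimal sketch (the folder paths are placeholders) showing that both kinds of dataset drop into DataLoader the same way:

import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Built-in: scans class subfolders and yields (image, label) pairs
folder_dataset = datasets.ImageFolder(root='path/to/train', transform=transform)

# Both are Dataset instances, so DataLoader treats them identically
loader = DataLoader(folder_dataset, batch_size=32, shuffle=True)
for images, labels in loader:
    print(images.shape)  # torch.Size([32, 3, 224, 224])
    break

# A custom subclass such as the ImageDataset defined below drops in the same way:
# loader = DataLoader(ImageDataset('path/to/images', transform), batch_size=32)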

1. Example of a custom MyDataset

For example, suppose I want to read images from a specified folder and build a dataset out of them:

import os
from PIL import Image
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # NOTE: os.listdir also returns non-image files; filter by extension if needed
        self.image_filenames = os.listdir(root_dir)
    
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, index):
        # Read the image
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        
        # Apply the transform to the image (if one was given)
        if self.transform is not None:
            image = self.transform(image)
        
        return image

ImageDataset is the dataset class that inherits from Dataset.
The most important parts are these three: the constructor __init__ and the two magic methods __len__ and __getitem__.

If I also want to preprocess the images, the code looks like this:

import os
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_filenames = os.listdir(root_dir)
    
    def __len__(self):
        return len(self.image_filenames)
    
    def __getitem__(self, index):
        # Read the image
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        
        # Apply the transform to the image (if one was given)
        if self.transform is not None:
            image = self.transform(image)
        
        return image

# Define the transform pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the image
    transforms.ToTensor(),          # Convert the image to a tensor
    transforms.Normalize(           # Normalize the image
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Create a dataset instance
dataset = ImageDataset(root_dir='path/to/images', transform=transform)

# Fetch the first image
image = dataset[0]

Since transforms.ToTensor() can directly convert a PIL.Image.Image object into a tensor, it can be used here as-is; there is no need to convert the PIL.Image.Image to an ndarray separately. Internally, transforms.ToTensor() first converts the PIL.Image.Image object into a numpy.ndarray and then converts that into a tensor.
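
A minimal sketch of this behavior, using a dummy image created in memory so no files are needed:

from PIL import Image
import torchvision.transforms as transforms

# Create a dummy 224x224 RGB image in memory
pil_image = Image.new('RGB', (224, 224), color=(128, 64, 32))

tensor = transforms.ToTensor()(pil_image)
print(type(tensor))   # <class 'torch.Tensor'>
print(tensor.shape)   # torch.Size([3, 224, 224]) -- channels first
print(tensor.min().item(), tensor.max().item())  # uint8 values scaled into [0, 1]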

2. What does torch.utils.data.Dataset look like?

Let's open the Dataset class definition and take a look:

class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices.  To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, Dataset):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function

ChatGPT's interpretation: this is the definition of the Dataset class, an abstract class; all datasets that represent a map from keys to data samples should inherit from it. All subclasses should override the __getitem__ method to support fetching a data sample for a given key. Subclasses can also optionally override the __len__ method, which many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader use to obtain the size of the dataset. If the keys of the dataset are not integers, you need to provide a custom sampler to make it compatible with torch.utils.data.DataLoader. In addition, the Dataset class provides some other methods and attributes, such as the __add__ method, the __getattr__ method, and so on.
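
For instance, a small sketch with toy in-memory data showing the __add__ method in action: two datasets can be concatenated with the + operator, which returns a ConcatDataset:

from torch.utils.data import Dataset, ConcatDataset

class TinyDataset(Dataset):
    """A toy map-style dataset over an in-memory list."""
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

a = TinyDataset([1, 2, 3])
b = TinyDataset([4, 5])
combined = a + b  # Dataset.__add__ returns a ConcatDataset
print(isinstance(combined, ConcatDataset))  # True
print(len(combined), combined[3])           # 5 4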

raise NotImplementedError indicates that the method has not been implemented and must be implemented in a subclass. In Python, raising a NotImplementedError is a convenient, conventional way to remind developers that the method still needs to be implemented. Of course, you can also simply implement these two methods directly in the subclass instead of relying on NotImplementedError.
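
A quick sketch of what happens if a subclass forgets the override: indexing falls through to the base class shown above and raises the exception:

from torch.utils.data import Dataset

class IncompleteDataset(Dataset):
    pass  # __getitem__ deliberately not overridden

ds = IncompleteDataset()
try:
    ds[0]  # falls through to Dataset.__getitem__ in the base class
except NotImplementedError:
    print('__getitem__ must be implemented in the subclass')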

3. Summary of some custom Dataset subclasses I have used

3.1. In an image denoising task, splitting a single image into multiple sub-image patches for training

import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os

class PatchDataset(Dataset):
    def __init__(self, noisy_image_folder, clean_image_folder, patch_size=64):
        self.noisy_image_folder = noisy_image_folder
        self.clean_image_folder = clean_image_folder
        self.patch_size = patch_size
        self.transform = transforms.Compose([
            transforms.Resize(patch_size + 16),
            transforms.RandomCrop(patch_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            )
        ])
        self.noisy_image_paths = [os.path.join(noisy_image_folder, x) for x in os.listdir(noisy_image_folder)]
        self.clean_image_paths = [os.path.join(clean_image_folder, x) for x in os.listdir(clean_image_folder)]
    
    def __len__(self):
        return len(self.noisy_image_paths)
    
    def __getitem__(self, idx):
        # Read the images as grayscale
        noisy_image = Image.open(self.noisy_image_paths[idx]).convert('L')
        clean_image = Image.open(self.clean_image_paths[idx]).convert('L')
        # Split the images into patches
        patches = []
        for i in range(4):  # split each image into a 4x4 grid of 16 patches
            for j in range(4):  # assumes images are at least 4 * patch_size on each side
                x = j * self.patch_size
                y = i * self.patch_size
                noisy_patch = noisy_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                clean_patch = clean_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                noisy_patch = self.transform(noisy_patch)
                clean_patch = self.transform(clean_patch)
                patches.append((noisy_patch, clean_patch))
        return patches

In this example, we define a custom dataset class called PatchDataset that inherits from PyTorch's Dataset class. In the __init__ function, we pass in the noisy image folder path noisy_image_folder, the clean image folder path clean_image_folder, and the patch size patch_size, and we define the transform self.transform. In the __getitem__ function, we read the noisy image and the clean image, split both into patches, and return the resulting 16 patches as a list, where each element is a tuple containing a noisy patch and the corresponding clean patch.
Note that in this example we normalize both the noisy and the clean patches. In an image denoising task, the noisy images are fed into the model for training while the clean images serve as labels for supervised learning, so applying the same normalization to both simplifies the code and helps training.
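
Since __getitem__ returns a list of 16 (noisy, clean) tuples rather than a single sample, the default collate adds an extra batch dimension around every tensor. A sketch of one way to consume it (folder paths are placeholders):

import torch
from torch.utils.data import DataLoader

dataset = PatchDataset('path/to/noisy', 'path/to/clean', patch_size=64)

# batch_size=1 keeps the default collate simple: each batch is the 16 patch
# pairs of one image, with every tensor wrapped in a leading batch dimension
loader = DataLoader(dataset, batch_size=1, shuffle=True)

for patches in loader:
    # stack the 16 patch pairs of this image into two training tensors
    noisy_batch = torch.stack([p[0].squeeze(0) for p in patches])  # (16, 1, 64, 64)
    clean_batch = torch.stack([p[1].squeeze(0) for p in patches])  # (16, 1, 64, 64)
    print(noisy_batch.shape, clean_batch.shape)
    break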

3.2. In the HDR image reconstruction task, what the transform in the dataset class should look like (.hdr/.exr files are opened differently from .jpg)

import torch
from torch.utils.data import Dataset
from torchvision import transforms
import cv2
import os

class PatchDataset(Dataset):
    def __init__(self, ldr_image_folder, hdr_image_folder, patch_size=64):
        self.ldr_image_folder = ldr_image_folder
        self.hdr_image_folder = hdr_image_folder
        self.patch_size = patch_size
        # NOTE: transforms.Resize / RandomCrop do not accept numpy arrays, so
        # ToTensor comes first to turn the HWC ndarray into a CHW tensor
        # (tensor input for Resize/RandomCrop requires torchvision >= 0.8)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(patch_size + 16),
            transforms.RandomCrop(patch_size),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ])
        self.ldr_image_paths = [os.path.join(ldr_image_folder, x) for x in os.listdir(ldr_image_folder)]
        self.hdr_image_paths = [os.path.join(hdr_image_folder, x) for x in os.listdir(hdr_image_folder)]
    
    def __len__(self):
        return len(self.ldr_image_paths)
    
    def read_hdr_image(self, path):
        # Read the HDR image preserving its bit depth
        hdr_image = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
        # Rescale the pixel values to the assumed original range
        hdr_image = hdr_image / 65535.0 * 100.0  # assumes the original range is [0, 100]
        return hdr_image
    
    def __getitem__(self, idx):
        # Read the LDR and HDR images
        ldr_image = cv2.imread(self.ldr_image_paths[idx], cv2.IMREAD_COLOR)
        hdr_image = self.read_hdr_image(self.hdr_image_paths[idx])
        # Split the images into patches
        patches = []
        for i in range(4):  # split each image into a 4x4 grid of 16 patches
            for j in range(4):
                x = j * self.patch_size
                y = i * self.patch_size
                ldr_patch = ldr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                hdr_patch = hdr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                ldr_patch = self.transform(ldr_patch)
                hdr_patch = self.transform(hdr_patch)
                patches.append((ldr_patch, hdr_patch))
        return patches

We first define a read_hdr_image() function that uses OpenCV's imread() function to read the HDR image and rescale the pixel values to their assumed original range. In the __getitem__ function, we read the LDR image and the HDR image, split both into patches, and return the resulting 16 patches as a list, where each element is a tuple containing an LDR patch and the corresponding HDR patch.

Because this code was generated by ChatGPT, whether it is correct still needs to be verified on a real machine. The pixel values of an HDR image in .hdr or .exr format are usually floating point numbers (for example, in OpenEXR the pixel data type is FLOAT), and they usually do not lie in 0 to 255 but in a larger range; the exact range depends on the acquisition device and the parameters used during image processing.
One point to verify in particular: when cv2.imread is called with cv2.IMREAD_ANYDEPTH, OpenCV reads .hdr/.exr files as float32 values as stored, without scaling them, so the step hdr_image / 65535.0 * 100.0 (dividing by 2**16 - 1 and multiplying by 100) only makes sense if the data was actually stored as 16-bit integers, and the assumed original range [0, 100] likewise depends on how the images were produced.
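
A quick way to check what your own files actually contain before choosing a rescaling factor (the file path is a placeholder):

import cv2

# IMREAD_ANYDEPTH | IMREAD_COLOR preserves the stored bit depth
# (recent opencv-python builds may require the environment variable
#  OPENCV_IO_ENABLE_OPENEXR=1 before .exr files can be read)
img = cv2.imread('path/to/image.hdr', cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
print(img.dtype)             # e.g. float32 for .hdr/.exr files
print(img.min(), img.max())  # inspect the actual value range before rescaling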

Origin: blog.csdn.net/weixin_46274756/article/details/130168398