Foreword
Besides datasets.ImageFolder, another very common way to build a dataset in deep learning projects is to write a custom one by inheriting from the torch.utils.data.Dataset class.
You can refer to my previous blog: [Deep Learning] How to use datasets.ImageFolder
Is the object returned by datasets.ImageFolder the same type as an instance of a custom dataset (such as MyDataset) that inherits from torch.utils.data.Dataset?
Effectively, yes. ImageFolder is itself a subclass of torch.utils.data.Dataset, so both objects are instances of Dataset. Both implement the __len__ and __getitem__ methods, so both can be passed to torch.utils.data.DataLoader for iteration and batching. Their concrete classes and implementations differ, but because they follow the same Dataset interface, they can be used interchangeably wherever a Dataset is expected.
1. Example of a custom dataset
For example, I want to read images from a specified folder to generate a dataset:
import os
from PIL import Image
from torch.utils.data import Dataset


class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort so that the index-to-file mapping is deterministic
        self.image_filenames = sorted(os.listdir(root_dir))

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, index):
        # Read the image
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        # Apply the transform to the image (if one was given)
        if self.transform is not None:
            image = self.transform(image)
        return image
ImageDataset is the custom dataset that inherits from Dataset.
The three essential parts are the constructor and the two magic methods, __len__ and __getitem__.
If I also want to preprocess the images, the code looks like this:
import os
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import Dataset


class ImageDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # Sort so that the index-to-file mapping is deterministic
        self.image_filenames = sorted(os.listdir(root_dir))

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, index):
        # Read the image
        image_path = os.path.join(self.root_dir, self.image_filenames[index])
        image = Image.open(image_path).convert('RGB')
        # Apply the transform to the image (if one was given)
        if self.transform is not None:
            image = self.transform(image)
        return image


# Define the transform
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the image
    transforms.ToTensor(),          # Convert the image to a tensor
    transforms.Normalize(           # Normalize the image
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

# Create the dataset instance
dataset = ImageDataset(root_dir='path/to/images', transform=transform)

# Get the first image
image = dataset[0]
Since transforms.ToTensor() converts a PIL.Image.Image object directly into a tensor, there is no need to convert the PIL image to an ndarray yourself. Internally, transforms.ToTensor() first converts the PIL.Image.Image object to a numpy.ndarray and then to a tensor. For 8-bit images it also rearranges the layout from H x W x C to C x H x W and scales the values from [0, 255] to [0.0, 1.0].
2. What does torch.utils.data.Dataset look like?
Let's open the source code of the Dataset class and take a look:
class Dataset(Generic[T_co]):
    r"""An abstract class representing a :class:`Dataset`.

    All datasets that represent a map from keys to data samples should subclass
    it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
    data sample for a given key. Subclasses could also optionally overwrite
    :meth:`__len__`, which is expected to return the size of the dataset by many
    :class:`~torch.utils.data.Sampler` implementations and the default options
    of :class:`~torch.utils.data.DataLoader`.

    .. note::
      :class:`~torch.utils.data.DataLoader` by default constructs a index
      sampler that yields integral indices. To make it work with a map-style
      dataset with non-integral indices/keys, a custom sampler must be provided.
    """
    functions: Dict[str, Callable] = {}

    def __getitem__(self, index) -> T_co:
        raise NotImplementedError

    def __add__(self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]':
        return ConcatDataset([self, other])

    # No `def __len__(self)` default?
    # See NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
    # in pytorch/torch/utils/data/sampler.py

    def __getattr__(self, attribute_name):
        if attribute_name in Dataset.functions:
            function = functools.partial(Dataset.functions[attribute_name], self)
            return function
        else:
            raise AttributeError

    @classmethod
    def register_function(cls, function_name, function):
        cls.functions[function_name] = function

    @classmethod
    def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
        if function_name in cls.functions:
            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))

        def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
            result_pipe = cls(source_dp, *args, **kwargs)
            if isinstance(result_pipe, Dataset):
                if enable_df_api_tracing or isinstance(source_dp, DFIterDataPipe):
                    if function_name not in UNTRACABLE_DATAFRAME_PIPES:
                        result_pipe = result_pipe.trace_as_dataframe()

            return result_pipe

        function = functools.partial(class_function, cls_to_register, enable_df_api_tracing)
        cls.functions[function_name] = function
ChatGPT's interpretation: this is the definition of the Dataset class, an abstract class that every dataset representing a map from keys to data samples should inherit from. All subclasses should override the __getitem__ method so that a data sample can be fetched for a given key. Subclasses can also optionally override the __len__ method, which many torch.utils.data.Sampler implementations and the default options of torch.utils.data.DataLoader rely on to obtain the size of the dataset. If the keys of the dataset are not integers, a custom sampler must be provided to make the dataset work with torch.utils.data.DataLoader. The Dataset class also provides a few other methods, such as __add__ (which combines two datasets into a ConcatDataset) and __getattr__.
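The note about non-integer keys can be illustrated with a small sketch. DictDataset and KeySampler are hypothetical names invented for this example, and the sample values are made up; the sampler simply yields the string keys that the dataset's __getitem__ expects.

```python
import torch
from torch.utils.data import Dataset, DataLoader, Sampler


class DictDataset(Dataset):
    """Hypothetical map-style dataset whose keys are strings."""

    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, key):
        return torch.tensor(self.samples[key])


class KeySampler(Sampler):
    """Yields the dataset's string keys instead of integer indices."""

    def __init__(self, keys):
        self.keys = list(keys)

    def __iter__(self):
        return iter(self.keys)

    def __len__(self):
        return len(self.keys)


samples = {'cat': [1.0, 2.0], 'dog': [3.0, 4.0], 'bird': [5.0, 6.0]}
dataset = DictDataset(samples)

# The custom sampler replaces the default integer index sampler
loader = DataLoader(dataset, sampler=KeySampler(samples.keys()), batch_size=2)
batches = [batch for batch in loader]
print(batches)
```

Without the custom sampler, the DataLoader would ask for dataset[0], dataset[1], and so on, which this dataset cannot answer.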
raise NotImplementedError indicates that the base class does not implement the method and that subclasses are expected to provide it. Raising NotImplementedError is the conventional Python way to tell developers that a method must be overridden. Note that in the source above only __getitem__ raises NotImplementedError; __len__ has no default implementation at all, so a subclass that needs len() must define it itself.
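A quick way to observe this behaviour is to subclass Dataset without overriding anything. IncompleteDataset is a made-up name for this sketch.

```python
from torch.utils.data import Dataset


class IncompleteDataset(Dataset):
    """Hypothetical subclass that overrides neither __getitem__ nor __len__."""
    pass


dataset = IncompleteDataset()

# Indexing falls through to Dataset.__getitem__, which raises NotImplementedError
try:
    dataset[0]
    getitem_error = None
except NotImplementedError as exc:
    getitem_error = type(exc).__name__

# There is no default __len__, so len() fails with a plain TypeError
try:
    len(dataset)
    len_error = None
except TypeError as exc:
    len_error = type(exc).__name__

print(getitem_error, len_error)
```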
3. Summary of some inherited dataset classes I have used
3.1. Image denoising: splitting a single image into multiple patches (sub-images) for training
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import os


class PatchDataset(Dataset):
    def __init__(self, noisy_image_folder, clean_image_folder, patch_size=64):
        self.noisy_image_folder = noisy_image_folder
        self.clean_image_folder = clean_image_folder
        self.patch_size = patch_size
        # Only deterministic transforms here: a random crop applied separately to the
        # noisy and the clean patch would select different regions and break the pairing
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            )
        ])
        # Sort both listings so that index i refers to the same image in both folders
        # (assumes the two folders use matching file names)
        self.noisy_image_paths = [os.path.join(noisy_image_folder, x) for x in sorted(os.listdir(noisy_image_folder))]
        self.clean_image_paths = [os.path.join(clean_image_folder, x) for x in sorted(os.listdir(clean_image_folder))]

    def __len__(self):
        return len(self.noisy_image_paths)

    def __getitem__(self, idx):
        # Read the images as single-channel (grayscale)
        noisy_image = Image.open(self.noisy_image_paths[idx]).convert('L')
        clean_image = Image.open(self.clean_image_paths[idx]).convert('L')
        # Split each image into patches
        # (assumes the image is at least 4 * patch_size pixels on each side)
        patches = []
        for i in range(4):  # 4 x 4 grid, i.e. 16 patches per image
            for j in range(4):
                x = j * self.patch_size
                y = i * self.patch_size
                noisy_patch = noisy_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                clean_patch = clean_image.crop((x, y, x + self.patch_size, y + self.patch_size))
                noisy_patch = self.transform(noisy_patch)
                clean_patch = self.transform(clean_patch)
                patches.append((noisy_patch, clean_patch))
        return patches
In this example we define a custom dataset class called PatchDataset that inherits from PyTorch's Dataset class. The __init__ function takes the noisy image folder path noisy_image_folder, the clean (noise-free) image folder path clean_image_folder and the patch size patch_size, and defines the transformation self.transform. In the __getitem__ function we read the noisy image and the clean image, split each into a 4 x 4 grid of patches, and return the resulting list of 16 elements, each of which is a tuple containing a noisy patch and the corresponding clean patch.
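The patch coordinates used in __getitem__ can be looked at in isolation. The sketch below uses two hypothetical helpers, grid_boxes and fits, which are not part of the class above; they compute the same crop boxes and check that an image is large enough for the whole grid.

```python
def grid_boxes(patch_size, rows=4, cols=4):
    """Return (left, upper, right, lower) boxes for a rows x cols grid of square patches."""
    boxes = []
    for i in range(rows):
        for j in range(cols):
            x = j * patch_size
            y = i * patch_size
            boxes.append((x, y, x + patch_size, y + patch_size))
    return boxes


def fits(image_size, patch_size, rows=4, cols=4):
    """Check that a (width, height) image is large enough for the whole grid."""
    width, height = image_size
    return width >= cols * patch_size and height >= rows * patch_size


boxes = grid_boxes(patch_size=64)
print(len(boxes))            # 16
print(boxes[0], boxes[-1])   # (0, 0, 64, 64) (192, 192, 256, 256)
print(fits((256, 256), 64))  # True
print(fits((200, 256), 64))  # False
```

With patch_size=64, the grid covers a 256 x 256 region, so smaller images would need to be resized or padded first.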
Note that PatchDataset normalizes both the noisy and the clean patches in the same way. In the image denoising task the noisy image is the model input and the clean image is the label for supervised learning, so the two need to be in the same value range for the loss between the model output and the label to be meaningful. Using one shared transform for both also keeps the code simple.
3.2. HDR image reconstruction: what should the transform in the dataset class look like? (.hdr/.exr files are opened differently from .jpg files)
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import cv2
import numpy as np
import os


class PatchDataset(Dataset):
    def __init__(self, ldr_image_folder, hdr_image_folder, patch_size=64):
        self.ldr_image_folder = ldr_image_folder
        self.hdr_image_folder = hdr_image_folder
        self.patch_size = patch_size
        # ToTensor accepts the NumPy arrays returned by OpenCV; Resize and RandomCrop do not,
        # and separate random crops would misalign the LDR/HDR pair anyway
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
        ])
        # Sort both listings so that index i refers to the same scene in both folders
        # (assumes the two folders use matching file names)
        self.ldr_image_paths = [os.path.join(ldr_image_folder, x) for x in sorted(os.listdir(ldr_image_folder))]
        self.hdr_image_paths = [os.path.join(hdr_image_folder, x) for x in sorted(os.listdir(hdr_image_folder))]

    def __len__(self):
        return len(self.ldr_image_paths)

    def read_hdr_image(self, path):
        # Read the HDR image without reducing it to 8 bits
        hdr_image = cv2.imread(path, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
        if hdr_image is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        if hdr_image.dtype == np.uint16:
            # 16-bit integer file: map back to the original range (assumed to be [0, 100])
            hdr_image = hdr_image / 65535.0 * 100.0
        # Floating-point data (.hdr/.exr) is kept as-is
        return hdr_image.astype(np.float32)

    def __getitem__(self, idx):
        # Read the images
        ldr_image = cv2.imread(self.ldr_image_paths[idx], cv2.IMREAD_COLOR)
        hdr_image = self.read_hdr_image(self.hdr_image_paths[idx])
        # Split each image into patches
        # (assumes the image is at least 4 * patch_size pixels on each side)
        patches = []
        for i in range(4):  # 4 x 4 grid, i.e. 16 patches per image
            for j in range(4):
                x = j * self.patch_size
                y = i * self.patch_size
                ldr_patch = ldr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                hdr_patch = hdr_image[y:y+self.patch_size, x:x+self.patch_size, :]
                ldr_patch = self.transform(ldr_patch)
                hdr_patch = self.transform(hdr_patch)
                patches.append((ldr_patch, hdr_patch))
        return patches
We first define a read_hdr_image() method that uses OpenCV's imread() function to read the HDR image and restore the pixel values to their original range. In the __getitem__ function we read the LDR image and the HDR image, split each into a 4 x 4 grid of patches, and return the resulting list of 16 elements, where each element is a tuple containing an LDR patch and the corresponding HDR patch.
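Because __getitem__ returns a list of 16 pairs, the DataLoader's default collation would keep that list structure, giving 16 pairs of batched tensors per batch. One possible way to get a single flat batch of patches instead is a custom collate_fn. The sketch below is an assumption-laden illustration: ToyPatchDataset is a hypothetical stand-in that returns constant tensors in the same list-of-pairs layout, and flatten_patches is a made-up helper name.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyPatchDataset(Dataset):
    """Hypothetical stand-in: every item is a list of 16 (input, target) patch pairs."""

    def __init__(self, num_images, patch_size=8):
        self.num_images = num_images
        self.patch_size = patch_size

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        shape = (3, self.patch_size, self.patch_size)
        return [(torch.zeros(shape), torch.ones(shape)) for _ in range(16)]


def flatten_patches(batch):
    """Merge the patch lists of all images in the batch into two stacked tensors."""
    inputs = torch.stack([pair[0] for patches in batch for pair in patches])
    targets = torch.stack([pair[1] for patches in batch for pair in patches])
    return inputs, targets


loader = DataLoader(ToyPatchDataset(num_images=4), batch_size=2, collate_fn=flatten_patches)
inputs, targets = next(iter(loader))
print(inputs.shape, targets.shape)
```

Here batch_size counts images, so each batch holds 2 x 16 = 32 patches.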
Because this code was generated by ChatGPT, whether it is correct still needs to be verified by running it. The pixel values of an HDR image in .hdr or .exr format are usually floating-point numbers (for example, in OpenEXR the pixel data type can be FLOAT) and are usually not limited to 0 to 255 but span a larger range. The exact range depends on the acquisition device and the parameters used during image processing.
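When OpenCV reads an image without the cv2.IMREAD_ANYDEPTH flag, it converts the pixel values to 8 bits (0 to 255). With cv2.IMREAD_ANYDEPTH, as used here, the stored bit depth is kept: .hdr and .exr files are returned as 32-bit floating-point arrays, and 16-bit integer files (such as 16-bit PNG or TIFF) as uint16 arrays. The division by 65535 (i.e. 2**16 - 1) followed by the multiplication by 100 in read_hdr_image is therefore only appropriate for 16-bit integer data, where it maps the values back to an assumed original range of [0, 100]; floating-point .hdr or .exr data should not be rescaled this way. Depending on the OpenCV build, reading .exr files may also require OpenEXR support to be enabled.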
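The dtype-dependent handling can be sketched with NumPy alone. to_float_image is a hypothetical helper, and the arrays are synthetic stand-ins for what an image reader might return; the sketch rescales integer data to [0, 1] and leaves floating-point data untouched.

```python
import numpy as np


def to_float_image(image):
    """Convert an image array to float32, rescaling only integer data to [0, 1]."""
    if image.dtype == np.uint8:
        return image.astype(np.float32) / 255.0
    if image.dtype == np.uint16:
        return image.astype(np.float32) / 65535.0
    # Floating-point data (e.g. HDR radiance values) is returned unscaled
    return image.astype(np.float32)


# Synthetic stand-ins for an 8-bit LDR image, a 16-bit image and a float HDR image
ldr_like = np.full((2, 2, 3), 255, dtype=np.uint8)
sixteen_bit = np.full((2, 2, 3), 65535, dtype=np.uint16)
hdr_like = np.full((2, 2, 3), 37.5, dtype=np.float32)

print(to_float_image(ldr_like).max())     # 1.0
print(to_float_image(sixteen_bit).max())  # 1.0
print(to_float_image(hdr_like).max())     # 37.5
```

Checking the dtype before rescaling avoids shrinking genuine HDR values by a factor that only makes sense for integer files.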