# 图像超分辨率重建 Dataset 的编写：直接从文件夹中读取图像
# (Image super-resolution dataset that reads HR images directly from a folder.)

from PIL import Image
import numpy as np
from torch.utils.data import Dataset
import os
from torchvision import transforms


def convert_rgb_to_y(img, dim_order='hwc'):
    """Convert an RGB image to its luma (Y) channel using ITU-R BT.601.

    The formula matches MATLAB's ``rgb2ycbcr`` for 8-bit data: RGB values
    are expected in [0, 255] and the resulting Y lies in [16, 235].

    Args:
        img: RGB array-like that supports ``img[..., c]`` / ``img[c]``
            indexing (e.g. a NumPy array or a torch tensor), values in
            [0, 255].
        dim_order: ``'hwc'`` when the channel axis is last; any other value
            is treated as channel-first (``'chw'``).

    Returns:
        The Y channel, i.e. ``img`` with the channel axis removed.
    """
    # BT.601 coefficients for [0, 1] input (65.481, 128.553, 24.966) scaled
    # by 256/255 so that dividing by 256 gives the same result for [0, 255]
    # input. Red is 65.738 (not the commonly copied typo 64.738, which maps
    # pure white to ~234 instead of 235).
    r_coef, g_coef, b_coef = 65.738, 129.057, 25.064
    if dim_order == 'hwc':
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    else:
        r, g, b = img[0], img[1], img[2]
    return 16. + (r_coef * r + g_coef * g + b_coef * b) / 256.


def is_image_file(filename):
    """Return True if *filename* ends with a recognised image extension.

    The check is case-insensitive and requires the extension to follow a
    dot, so ``photo.PNG`` and ``photo.Jpeg`` match while ``notapng`` and
    ``png`` do not.

    Args:
        filename: A bare file name or path string.

    Returns:
        True for ``.jpeg``, ``.jpg``, ``.png`` or ``.gif`` files, else False.
    """
    # str.endswith accepts a tuple, so all suffixes are tested in one call;
    # lower() makes mixed-case extensions (e.g. ".Png", ".GIF") match too.
    return filename.lower().endswith(('.jpeg', '.jpg', '.png', '.gif'))


class dataset(Dataset):
    def __init__(self, path, scale):
        """Index the image files found directly inside *path*.

        Args:
            path: Directory to scan; only its immediate entries are listed
                (``os.listdir`` is not recursive).
            scale: Stored unchanged as ``self.scale``; presumably the
                super-resolution factor used when loading samples — confirm
                against the item-loading method.
        """
        super().__init__()
        self.scale = scale
        # Sorted so sample indices are stable across runs and platforms.
        self.hr_name = [
            os.path.join(path, entry)
            for entry in sorted(os.listdir(path))
            if is_image_file(entry)
        ]

    def __len__(self):
        return len(self.hr_name)

    def 

猜你喜欢

# 转载自 https://blog.csdn.net/qq_40107571/article/details/127054391