from PIL import Image
import numpy as np
from torch.utils.data import Dataset
import os
from torchvision import transforms
def convert_rgb_to_y(img, dim_order='hwc'):
    """Return the ITU-R BT.601 luma (Y) channel of an RGB image.

    Uses the studio-swing BT.601 formula rescaled for 8-bit-range inputs:
    ``Y = 16 + (65.738 R + 129.057 G + 25.064 B) / 256``. These are MATLAB
    ``rgb2ycbcr``'s coefficients (65.481, 128.553, 24.966) multiplied by
    256/255. Black (0, 0, 0) therefore maps to 16 and white (255, 255, 255)
    maps to 235.

    Args:
        img: array-like that supports ``[..., i]`` / ``[i]`` indexing and
            scalar arithmetic (e.g. a NumPy array or torch tensor). Channel
            values are presumably in [0, 255]; TODO confirm with callers.
        dim_order: ``'hwc'`` means the channels are on the last axis. Any
            other value makes the function read the channels from the first
            axis (channels-first, e.g. ``'chw'``).

    Returns:
        The Y channel, i.e. ``img`` with the channel axis removed.
    """
    if dim_order == 'hwc':
        r, g, b = img[..., 0], img[..., 1], img[..., 2]
    else:
        r, g, b = img[0], img[1], img[2]
    # 65.738 = 65.481 * 256 / 255. The frequently copied 64.738 is a typo
    # that maps pure white to ~234 instead of 235.
    return 16. + (65.738 * r + 129.057 * g + 25.064 * b) / 256.
def is_image_file(filename, extensions=('.jpeg', '.jpg', '.png', '.gif')):
    """Return True if *filename* ends with a recognised image extension.

    The check is case-insensitive and anchored on the dot. ``photo.JPG`` and
    ``icon.Gif`` are accepted, while names that merely end in the letters of
    an extension, such as ``datajpeg``, are rejected.

    Args:
        filename: file name or path to test.
        extensions: accepted suffixes, each including the leading dot. They
            are compared case-insensitively.

    Returns:
        ``True`` when the name ends with one of *extensions*, else ``False``.
    """
    # str.endswith accepts a tuple, so a single C-level call checks every suffix.
    suffixes = tuple(ext.lower() for ext in extensions)
    return filename.lower().endswith(suffixes)
class dataset(Dataset):
    """Dataset that indexes every image file found directly inside a folder.

    The folder is scanned once, non-recursively, at construction time.
    File names are sorted so the sample order is deterministic; only names
    accepted by ``is_image_file`` are kept.

    NOTE(review): only ``__init__`` and ``__len__`` are present in this copy
    of the file. The item accessor (``__getitem__``) that a map-style
    ``Dataset`` needs appears truncated, so confirm it is defined.

    Args:
        path: directory to scan for image files.
        scale: kept as ``self.scale``. It is presumably the super-resolution
            upscaling factor used when building samples; TODO confirm.
    """

    def __init__(self, path, scale):
        super().__init__()
        self.scale = scale
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        listing = os.listdir(path)
        listing.sort()
        # Full paths of the kept images ("hr" presumably = high-resolution).
        self.hr_name = [
            os.path.join(path, entry)
            for entry in listing
            if is_image_file(entry)
        ]

    def __len__(self):
        """Return the number of image files that were indexed."""
        return len(self.hr_name)
def
图像超分辨率重建dataset的编写,直接从文件夹中读取
猜你喜欢
转载自blog.csdn.net/qq_40107571/article/details/127054391
今日推荐
周排行