学习日志(十七):patch

import argparse
import glob
import h5py
import numpy as np
import PIL.Image as pil_image
from utils import convert_rgb_to_y

def _positive_int(text):
    """Argparse ``type`` callback: parse ``text`` as an int and require it to be > 0.

    ``--patch-size``, ``--stride`` and ``--scale`` are used as slice sizes,
    ``range`` steps, divisors and resize factors in ``train``. A zero or
    negative value would otherwise fail much later with an unrelated error
    (e.g. ``range()`` with step 0, division by zero) or silently yield no
    patches, so reject it at parse time with a normal usage error instead.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not an integer or is <= 0.
    """
    try:
        value = int(text)
    except ValueError:
        # Same wording argparse uses for a plain ``type=int`` failure.
        raise argparse.ArgumentTypeError('invalid int value: {!r}'.format(text)) from None
    if value <= 0:
        raise argparse.ArgumentTypeError('must be a positive integer, got {}'.format(value))
    return value


# Command-line options, parsed at import time into the module-level ``args``
# consumed by the ``__main__`` entry point.
# NOTE(review): the default paths point at the original author's machine;
# pass --images-dir / --output-path explicitly elsewhere.
parser = argparse.ArgumentParser(
    description='Cut aligned LR/HR Y-channel patches from a folder of images '
                'and store them in an HDF5 file.')
parser.add_argument('--images-dir', type=str, default='/home/radio/DS/SRCNN-pytorch-master/train_out',
                    help='directory of training images; every file in it is read')
parser.add_argument('--output-path', type=str, default='/home/radio/DS/SRCNN-pytorch-master/out.h5',
                    help='HDF5 file to write (overwritten if it already exists)')
parser.add_argument('--patch-size', type=_positive_int, default=320,
                    help='side length in pixels of each square patch')
parser.add_argument('--stride', type=_positive_int, default=320,
                    help='step in pixels between neighbouring patch origins')
parser.add_argument('--scale', type=_positive_int, default=2,
                    help='super-resolution factor: HR images are downscaled by this '
                         'factor and upscaled back to form the LR input')
args = parser.parse_args()


def _load_lr_hr(image_path, scale):
    """Load one image file and build its ``(lr, hr)`` training pair.

    The HR image is bicubically *resized* (not cropped) so both dimensions are
    multiples of ``scale``. The LR image is that HR image bicubically
    downscaled by ``scale`` and upscaled back to the HR size, so both outputs
    share the same spatial size. Both are converted to float32 and passed
    through ``convert_rgb_to_y``.

    Args:
        image_path: path of an image file PIL can open.
        scale: integer super-resolution factor (> 0).

    Returns:
        ``(lr, hr)``: the ``convert_rgb_to_y`` results for the LR and HR
        images (presumably 2-D Y-channel arrays — confirm against utils).
    """
    # ``with`` releases the underlying file handle; ``convert`` returns a new
    # image, so ``hr`` stays usable after the source file is closed.
    with pil_image.open(image_path) as img:
        hr = img.convert('RGB')
    print('hr:', hr.size)
    # Make both dimensions divisible by ``scale`` so the down/up round trip
    # below lands exactly back on the HR size (no leftover pixels).
    hr_width = (hr.width // scale) * scale
    hr_height = (hr.height // scale) * scale
    hr = hr.resize((hr_width, hr_height), resample=pil_image.BICUBIC)
    lr = hr.resize((hr_width // scale, hr_height // scale), resample=pil_image.BICUBIC)
    lr = lr.resize((lr.width * scale, lr.height * scale), resample=pil_image.BICUBIC)
    hr = convert_rgb_to_y(np.array(hr).astype(np.float32))
    lr = convert_rgb_to_y(np.array(lr).astype(np.float32))
    print('lr.shape:', lr.shape)
    return lr, hr


def _iter_patches(lr, hr, patch_size, stride):
    """Yield aligned ``(lr_patch, hr_patch)`` square windows of side ``patch_size``.

    Window origins step by ``stride`` along the first two axes. Only windows
    that fit entirely inside the image are produced, so every patch has the
    same shape.
    """
    for i in range(0, lr.shape[0] - patch_size + 1, stride):
        for j in range(0, lr.shape[1] - patch_size + 1, stride):
            yield (lr[i:i + patch_size, j:j + patch_size],
                   hr[i:i + patch_size, j:j + patch_size])


def train(args):
    """Build an HDF5 dataset of aligned LR/HR patches from a folder of images.

    Every entry matched by ``<args.images_dir>/*`` is processed in sorted
    order; each must be an image PIL can open. Each image becomes an
    ``(lr, hr)`` pair, which is cut into ``args.patch_size`` windows spaced
    ``args.stride`` apart. All windows are stacked and written to
    ``args.output_path`` as datasets ``'lr'`` and ``'hr'``. An existing file
    at that path is overwritten.

    Args:
        args: namespace with ``images_dir``, ``output_path``, ``patch_size``,
            ``stride`` and ``scale`` attributes (as produced by the CLI parser).
    """
    lr_patches = []
    hr_patches = []

    # Search the folder and sort the files. ``glob.escape`` stops characters
    # such as '[' or '*' in the directory name being treated as wildcards.
    pattern = '{}/*'.format(glob.escape(args.images_dir))
    for image_path in sorted(glob.glob(pattern)):
        lr, hr = _load_lr_hr(image_path, args.scale)
        for lr_patch, hr_patch in _iter_patches(lr, hr, args.patch_size, args.stride):
            lr_patches.append(lr_patch)
            hr_patches.append(hr_patch)
            # Progress: equals np.array(lr_patches).shape (all patches share
            # one shape) without re-copying the whole list on every patch.
            print((len(lr_patches),) + lr_patch.shape)

    lr_patches = np.array(lr_patches)
    hr_patches = np.array(hr_patches)
    print(lr_patches.shape)

    # Open the output only once all patches exist, and via ``with`` so it is
    # closed even if writing fails. A crash during preprocessing therefore
    # no longer leaves a truncated file behind.
    with h5py.File(args.output_path, 'w') as h5_file:
        h5_file.create_dataset('lr', data=lr_patches)
        h5_file.create_dataset('hr', data=hr_patches)


if __name__ == '__main__':
    # Script entry point: build the patch dataset from the CLI options parsed above.
    train(args=args)

说明:使用的图片是裁剪之后大小为 320x320 的图片,patch_size 和 stride 均为 320,因此每张图片恰好生成一个 320x320 的 patch;所有 patch 最终写入 h5 文件(脚本以 'w' 模式打开该文件:文件不存在时新建,已存在时覆盖)。

猜你喜欢

转载自 https://blog.csdn.net/weixin_44825185/article/details/109449687