def data_loader(FLAGS):
    """Build the TF1 queue-based input pipeline for paired LR/HR PNG images.

    LR and HR images are matched by identical file names in the two input
    directories. LR images are normalized via ``preprocessLR`` (to [0, 1])
    and HR images via ``preprocess`` (to [-1, 1]) — both helpers are defined
    elsewhere in this project. In 'train' mode, random crop and random flip
    augmentation are applied (the same flip decision is used for both images
    so the pair stays aligned).

    Args:
        FLAGS: configuration object. Reads input_dir_LR, input_dir_HR, mode,
            task ('SRGAN' | 'SRResnet' | 'denoise'), crop_size, random_crop,
            flip, batch_size, queue_thread, name_queue_capacity and
            image_queue_capacity.

    Returns:
        Data namedtuple: (paths_LR, paths_HR, inputs, targets, image_count,
        steps_per_epoch) where inputs/targets are batched image tensors.

    Raises:
        ValueError: if either input directory is the string 'None' or does
            not exist on disk.
        Exception: if the LR directory contains no .png files.
    """
    with tf.device('/cpu:0'):
        # Container for the pipeline outputs.
        Data = collections.namedtuple(
            'Data',
            'paths_LR, paths_HR, inputs, targets, image_count, steps_per_epoch')

        # Validate the input directories.
        if (FLAGS.input_dir_LR == 'None') or (FLAGS.input_dir_HR == 'None'):
            raise ValueError('Input directory is not provided')
        if (not os.path.exists(FLAGS.input_dir_LR)) or (not os.path.exists(FLAGS.input_dir_HR)):
            raise ValueError('Input directory not found')

        image_list_LR = os.listdir(FLAGS.input_dir_LR)
        image_list_LR = [name for name in image_list_LR if name.endswith('.png')]
        if len(image_list_LR) == 0:
            raise Exception('No png files in the input directory')

        # HR paths are built from the same (sorted) file names as the LR ones,
        # so the two lists stay pairwise aligned.
        image_list_LR_temp = sorted(image_list_LR)
        image_list_LR = [os.path.join(FLAGS.input_dir_LR, name) for name in image_list_LR_temp]
        image_list_HR = [os.path.join(FLAGS.input_dir_HR, name) for name in image_list_LR_temp]

        image_list_LR_tensor = tf.convert_to_tensor(image_list_LR, dtype=tf.string)
        image_list_HR_tensor = tf.convert_to_tensor(image_list_HR, dtype=tf.string)

        with tf.variable_scope('load_image'):
            # Paired filename queue; shuffle=False keeps LR/HR entries aligned
            # (shuffling happens later in tf.train.shuffle_batch for training).
            output = tf.train.slice_input_producer(
                [image_list_LR_tensor, image_list_HR_tensor],
                shuffle=False, capacity=FLAGS.name_queue_capacity)

            # Read and decode the images. (The original created an unused
            # tf.WholeFileReader here; tf.read_file does the reading directly.)
            image_LR = tf.read_file(output[0])
            image_HR = tf.read_file(output[1])
            input_image_LR = tf.image.decode_png(image_LR, channels=3)
            input_image_HR = tf.image.decode_png(image_HR, channels=3)
            input_image_LR = tf.image.convert_image_dtype(input_image_LR, dtype=tf.float32)
            input_image_HR = tf.image.convert_image_dtype(input_image_HR, dtype=tf.float32)

            # Runtime channel check on the LR image; the identity ops carry
            # the control dependency so the assert actually executes.
            assertion = tf.assert_equal(
                tf.shape(input_image_LR)[2], 3,
                message="image does not have 3 channels")
            with tf.control_dependencies([assertion]):
                input_image_LR = tf.identity(input_image_LR)
                input_image_HR = tf.identity(input_image_HR)

            # Normalize: LR to [0, 1], HR to [-1, 1] (project helpers).
            a_image = preprocessLR(input_image_LR)
            b_image = preprocess(input_image_HR)

            inputs, targets = [a_image, b_image]

        # Data augmentation.
        with tf.name_scope('data_preprocessing'):
            with tf.name_scope('random_crop'):
                # Random crop is only performed during training.
                if (FLAGS.random_crop is True) and FLAGS.mode == 'train':
                    print('[Config] Use random crop')
                    # For SR tasks the HR target has 4x the LR spatial size.
                    input_size = tf.shape(inputs)
                    offset_w = tf.cast(tf.floor(tf.random_uniform(
                        [], 0, tf.cast(input_size[1], tf.float32) - FLAGS.crop_size)),
                        dtype=tf.int32)
                    offset_h = tf.cast(tf.floor(tf.random_uniform(
                        [], 0, tf.cast(input_size[0], tf.float32) - FLAGS.crop_size)),
                        dtype=tf.int32)

                    if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
                        inputs = tf.image.crop_to_bounding_box(
                            inputs, offset_h, offset_w,
                            FLAGS.crop_size, FLAGS.crop_size)
                        # HR crop is scaled by 4 in both offset and size.
                        targets = tf.image.crop_to_bounding_box(
                            targets, offset_h * 4, offset_w * 4,
                            FLAGS.crop_size * 4, FLAGS.crop_size * 4)
                    elif FLAGS.task == 'denoise':
                        # Denoising: input and target share the same geometry.
                        inputs = tf.image.crop_to_bounding_box(
                            inputs, offset_h, offset_w,
                            FLAGS.crop_size, FLAGS.crop_size)
                        targets = tf.image.crop_to_bounding_box(
                            targets, offset_h, offset_w,
                            FLAGS.crop_size, FLAGS.crop_size)
                # Do not perform crop.
                else:
                    inputs = tf.identity(inputs)
                    targets = tf.identity(targets)

            with tf.variable_scope('random_flip'):
                # Random flip is only performed during training.
                if (FLAGS.flip is True) and (FLAGS.mode == 'train'):
                    print('[Config] Use random flip')
                    # One shared decision so input and target flip together.
                    decision = tf.random_uniform([], 0, 1, dtype=tf.float32)
                    input_images = random_flip(inputs, decision)
                    target_images = random_flip(targets, decision)
                else:
                    input_images = tf.identity(inputs)
                    target_images = tf.identity(targets)

            # Static shapes are required by the batching ops below.
            if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
                input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
                target_images.set_shape([FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
            elif FLAGS.task == 'denoise':
                input_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])
                target_images.set_shape([FLAGS.crop_size, FLAGS.crop_size, 3])

        # Batching: shuffled for training, deterministic order otherwise.
        if FLAGS.mode == 'train':
            paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.shuffle_batch(
                [output[0], output[1], input_images, target_images],
                batch_size=FLAGS.batch_size,
                capacity=FLAGS.image_queue_capacity + 4 * FLAGS.batch_size,
                min_after_dequeue=FLAGS.image_queue_capacity,
                num_threads=FLAGS.queue_thread)
        else:
            paths_LR_batch, paths_HR_batch, inputs_batch, targets_batch = tf.train.batch(
                [output[0], output[1], input_images, target_images],
                batch_size=FLAGS.batch_size,
                num_threads=FLAGS.queue_thread,
                allow_smaller_final_batch=True)

        steps_per_epoch = int(math.ceil(len(image_list_LR) / FLAGS.batch_size))

        if FLAGS.task == 'SRGAN' or FLAGS.task == 'SRResnet':
            inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
            targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size * 4, FLAGS.crop_size * 4, 3])
        elif FLAGS.task == 'denoise':
            inputs_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])
            targets_batch.set_shape([FLAGS.batch_size, FLAGS.crop_size, FLAGS.crop_size, 3])

    return Data(
        paths_LR=paths_LR_batch,
        paths_HR=paths_HR_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        image_count=len(image_list_LR),
        steps_per_epoch=steps_per_epoch
    )
# Notes on image data loading and preprocessing (part 1).
# Adapted from: blog.csdn.net/meailin/article/details/80033258