图像数据预处理 -- 数据增强、写入tfrecords

Augmentor 是一个很好用的图像数据增强 Python 库,支持多种图像变形与变换操作。

  • 下面这段代码展示的是基于图像分割的数据集,同时生成增强的图像及其对应的label
import glob
import os

import Augmentor

# Directory that holds the source images to augment.
# NOTE(review): augment() also reads AUGMENT_LABEL_SOURCE_DIR, which is not
# defined anywhere in this snippet -- it must be defined before calling it.
AUGMENT_SOURCE_DIR = './dataset/augment_dataset/'
# Directory where the augmented images are saved. The original author notes
# that only an absolute path seemed to work here (not verified).
AUGMENT_OUTPUT_DIR = 'F:/Pycharm Projects/Unet/dataset/augment_output'

def augment():
    """Generate augmented image/label pairs for every PNG in AUGMENT_SOURCE_DIR.

    One Augmentor pipeline is built per source image, with the directory
    AUGMENT_LABEL_SOURCE_DIR registered as its ground truth, so each sampled
    image is written to AUGMENT_OUTPUT_DIR together with its transformed
    label. Every operation is applied with probability 0.2 and five samples
    are drawn per source image.
    """
    # makedirs (instead of mkdir) also creates missing parent directories and
    # does not fail when the directory already exists.
    os.makedirs(AUGMENT_OUTPUT_DIR, exist_ok=True)

    # Collect the path of every source image.
    filenames = glob.glob(os.path.join(AUGMENT_SOURCE_DIR, '*.png'))
    for filename in filenames:
        # The pipeline source is a single image here so that its label can be
        # generated alongside it; per the original author, pass the directory
        # itself when labels are not needed.
        p = Augmentor.Pipeline(
            source_directory=filename,
            output_directory=AUGMENT_OUTPUT_DIR
        )
        # Directory holding the labels; per the original author, each label
        # must have the same file name as its image (rename beforehand).
        p.ground_truth(ground_truth_directory=AUGMENT_LABEL_SOURCE_DIR)
        # Rotation by at most 2 degrees to either side.
        p.rotate(probability=0.2, max_left_rotation=2, max_right_rotation=2)
        # Zoom by a factor between 1.1 and 1.2.
        p.zoom(probability=0.2, min_factor=1.1, max_factor=1.2)
        # Skew (perspective tilt).
        p.skew(probability=0.2)
        # Elastic distortion; the original author warns that grid_width and
        # grid_height must not exceed the size of the source image.
        p.random_distortion(probability=0.2, grid_width=20, grid_height=20, magnitude=1)
        # Shear by at most 2 degrees to either side.
        p.shear(probability=0.2, max_shear_left=2, max_shear_right=2)
        # Random crop keeping 80% of the area.
        p.crop_random(probability=0.2, percentage_area=0.8)
        # Random flip.
        p.flip_random(probability=0.2)
        # Number of augmented samples generated per source image.
        p.sample(n=5)

# Run the augmentation immediately when this snippet is executed.
augment()
  • 上述操作之后,增强后的图像和标签会生成在同一个文件夹(AUGMENT_OUTPUT_DIR)下,且图像与其 label 的文件名一一对应,所以下面将二者分别转移到各自的文件夹下:
def standard_img_and_lbl(dir):
    """Split Augmentor's mixed output into numbered image and label files.

    Every PNG in ``dir`` whose path contains 'image_original' is treated as
    an augmented image; its label path is derived by replacing the
    'image_original_' prefix with '_groundtruth_(1)_image_'. Each readable
    pair is written as '<k>.png' into AUGMENT_IMAGE_PATH and
    AUGMENT_LABEL_PATH, where k counts up from 0 without gaps.

    Args:
        dir: Directory containing the augmented images and labels.
    """
    os.makedirs(AUGMENT_IMAGE_PATH, exist_ok=True)
    os.makedirs(AUGMENT_LABEL_PATH, exist_ok=True)

    # Keep only the image files and sort them: the output index must not
    # depend on glob's unspecified ordering, nor be consumed by label files,
    # otherwise the numbering is non-deterministic and may contain gaps.
    image_files = sorted(
        name for name in glob.glob(dir + '/*.png')
        if 'image_original' in name
    )

    next_idx = 0
    for image_file in image_files:
        label_file = image_file.replace('image_original_', '_groundtruth_(1)_image_')
        img = cv2.imread(image_file)
        lbl = cv2.imread(label_file)
        # cv2.imread returns None instead of raising when a file is missing
        # or unreadable; skip such pairs so the numbering stays contiguous.
        if img is None or lbl is None:
            print('Skipping unreadable pair: %s / %s' % (image_file, label_file))
            continue
        out_name = '%d.png' % next_idx
        cv2.imwrite(os.path.join(AUGMENT_IMAGE_PATH, out_name), img)
        cv2.imwrite(os.path.join(AUGMENT_LABEL_PATH, out_name), lbl)
        next_idx += 1
  • 将图像写成 TFRecords 形式保存:TFRecords 文件是一种二进制文件,不对数据进行压缩,所以可以被快速加载到内存中。该格式不支持随机访问,因此适合大量数据的顺序流式读取,但不适用于快速分片或其他非连续存取。
def _write_tfrecords_split(record_path, image_dir, label_dir, first_idx, count,
                           label_size, split_name, report_every):
    """Serialize ``count`` image/label pairs into one TFRecords file.

    Pairs are read from '<idx>.png' in ``image_dir`` / ``label_dir`` for idx
    in [first_idx, first_idx + count). Images are resized to
    (INPUT_WIDTH, INPUT_HEIGHT), labels to ``label_size`` and binarized to
    0/1. Progress is printed whenever idx is a multiple of ``report_every``.

    Raises:
        IOError: If an image or label file cannot be read.
    """
    # The context manager closes the writer even if a pair fails to load.
    with tf.python_io.TFRecordWriter(record_path) as writer:
        for idx in range(first_idx, first_idx + count):
            image_file = os.path.join(image_dir, '%d.png' % idx)
            label_file = os.path.join(label_dir, '%d.png' % idx)
            image = cv2.imread(image_file)
            label = cv2.imread(label_file, 0)  # flag 0: load as grayscale
            # cv2.imread signals failure by returning None; fail with the
            # offending paths instead of an obscure error inside cv2.resize.
            if image is None or label is None:
                raise IOError('Cannot read image/label pair: %s / %s'
                              % (image_file, label_file))
            image = cv2.resize(image, (INPUT_WIDTH, INPUT_HEIGHT))
            label = cv2.resize(label, label_size)
            # Collapse every non-zero pixel to class 1 (binary mask).
            label[label != 0] = 1
            # The Example wraps the raw bytes of the label and the image.
            example = tf.train.Example(features=tf.train.Features(feature={
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label.tobytes()])),
                'image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()]))
            }))
            writer.write(example.SerializeToString())
            if idx % report_every == 0:
                print('Done %s writing %.2f%%' % (split_name, (idx - first_idx) / count * 100))


def write_image_to_tfrecords():
    """Write the train, validation and predict sets as TFRecords files.

    Train pairs are files 0..TRAIN_SET_SIZE-1 of the augmented folders,
    validation pairs are the following VALIDATION_SET_SIZE files, and predict
    pairs are files 0..PREDICT_SET_SIZE-1 of the original predict folders.
    All files are written under './dataset/my_set'.
    """
    # Folders holding the augmented images / labels.
    augment_image_path = AUGMENT_IMAGE_PATH
    augment_label_path = AUGMENT_LABEL_PATH
    records_dir = './dataset/my_set'
    # TFRecordWriter does not create missing directories.
    os.makedirs(records_dir, exist_ok=True)

    # train set
    # NOTE(review): train/validation labels are resized to the INPUT size
    # while predict labels use the OUTPUT size (kept as in the original);
    # confirm this is intended if the two sizes ever differ.
    _write_tfrecords_split(
        os.path.join(records_dir, TRAIN_SET_NAME),
        augment_image_path, augment_label_path,
        0, TRAIN_SET_SIZE,
        (INPUT_WIDTH, INPUT_HEIGHT), 'train_set', 100)
    # The original message wrongly said "test set" here.
    print('Done train set writing.')

    # validation set: the files right after the train range
    _write_tfrecords_split(
        os.path.join(records_dir, VALIDATION_SET_NAME),
        augment_image_path, augment_label_path,
        TRAIN_SET_SIZE, VALIDATION_SET_SIZE,
        (INPUT_WIDTH, INPUT_HEIGHT), 'validation_set', 10)
    print("Done validation_set writing")

    # predict set: read from the original (non-augmented) folders
    _write_tfrecords_split(
        os.path.join(records_dir, PREDICT_SET_NAME),
        ORIGIN_PREDICT_IMG_DIR, ORIGIN_PREDICT_LBL_DIR,
        0, PREDICT_SET_SIZE,
        (OUTPUT_WIDTH, OUTPUT_HEIGHT), 'predict_set', 10)
    print("Done predict_set writing")
  • 读取并验证TFRecords文件是否存储正确:
# Width, height and channel count used to reshape the decoded image.
INPUT_WIDTH, INPUT_HEIGHT, INPUT_CHANNEL = 512, 512, 3
# Width and height used to reshape the decoded label (OUTPUT_CHANNEL is not
# used by the code in this snippet).
OUTPUT_WIDTH, OUTPUT_HEIGHT, OUTPUT_CHANNEL = 512, 512, 1
# File name of the training-set TFRecords file.
TRAIN_SET_NAME = 'train_set.tfrecords'
# Directory that contains the TFRecords files.
TFRECORDS_DIR = './dataset/my_set'


# 读取图像及其对应的label
def read_image(file_queue):
    """Read one image/label pair from a queue of TFRecords file names.

    Args:
        file_queue: Queue of TFRecords file names, as accepted by
            ``tf.TFRecordReader.read``.

    Returns:
        A tuple ``(image, label)`` of uint8 tensors with shapes
        [INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNEL] and
        [OUTPUT_HEIGHT, OUTPUT_WIDTH].
    """
    # Reader that pulls serialized examples out of TFRecords files.
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(file_queue)
    # Parse the two byte-string features of the example.
    features = tf.parse_single_example(
            serialized_example,
            features={
                'label': tf.FixedLenFeature([], tf.string),
                'image_raw': tf.FixedLenFeature([], tf.string)
            }
    )
    # Decode the raw bytes back into uint8 pixels.
    # The bytes come from OpenCV arrays, which are laid out row-major as
    # (height, width, channels); reshaping as (width, height, ...) would
    # scramble any non-square image.
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    image = tf.reshape(image, [INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNEL])
    label = tf.decode_raw(features['label'], tf.uint8)
    label = tf.reshape(label, [OUTPUT_HEIGHT, OUTPUT_WIDTH])
    return image, label


# 显示图像和label
def read_check_tfrecords():
    """Read one example from the training TFRecords file and display it.

    Opens the file TRAIN_SET_NAME under TFRECORDS_DIR, decodes a single
    image/label pair with ``read_image`` and shows both in OpenCV windows
    until a key is pressed.
    """
    train_file_path = os.path.join(TFRECORDS_DIR, TRAIN_SET_NAME)
    train_image_filename_queue = tf.train.string_input_producer(
            string_tensor=tf.train.match_filenames_once(train_file_path),
            num_epochs=1,
            shuffle=True
    )
    train_images, train_labels = read_image(train_image_filename_queue)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Per the TF 1.x docs, num_epochs and match_filenames_once rely on
        # local variables, so those must be initialized as well.
        sess.run(tf.local_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            example, label = sess.run([train_images, train_labels])
            cv2.imshow('image', example)
            # NOTE(review): assumes the stored mask only holds 0/1 values (as
            # produced by the writer shown earlier); scale to 0/255 so the
            # foreground is actually visible.
            cv2.imshow('label', label * 255)
            cv2.waitKey(0)
        finally:
            # Always stop and join the queue-runner threads, even when
            # reading or displaying fails.
            coord.request_stop()
            coord.join(threads)
    print('Done reading and checking.')
    
# read_check_tfrecords()

猜你喜欢

转载自blog.csdn.net/francislucien2017/article/details/86991610