Data Augmentation: Crop and Join Two Images (Crop and Joint)

1. Purpose

Sometimes the data we can collect covers only a limited set of scenes, the objects of interest occupy too small a fraction of the original image, or the object positions are imbalanced (for example, most objects sit near the image center and there are almost none near the edges).

In that case we need to augment the original dataset. Besides common transforms such as translation, rotation, cropping, and brightness adjustment, an effective way to handle the problems above is to join different images together: joining forces each object away from the center point, which solves the position imbalance. Cropping away the background outside the objects before joining raises the objects' share of the whole image, which solves the problem that the objects in the original images are too small and most of each image is background.

The joined result looks like this:

2. Crop and Joint Code

The complete code is below. Only two arguments need to be specified: raw_data_path and new_data_path, which are the path to the original dataset and the path where the augmented data will be saved, respectively.

The original dataset keeps the images and their XML annotation files in the same folder, for example:
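(The file names below are only placeholders to illustrate the layout: each image sits next to an XML annotation file with the same base name.)

raw_data_path/
    0001.jpg
    0001.xml
    0002.jpg
    0002.xml
    ...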

"""
Crop and join two images so that the resulting image contains all of their target boxes.
"""
import numpy as np
import random
import xml.etree.ElementTree as ET
import cv2
import os
import argparse

from bboxes2xml import bboxes2xml


def list_dir(path, list_name, suffix='xml'):  # collect all file paths with the given suffix into list_name
    for file in os.listdir(path):
        file_path = os.path.join(path, file)
        if os.path.isdir(file_path):
            list_dir(file_path, list_name, suffix)  # keep the same suffix when recursing into sub-folders
        else:
            if file_path.split('.')[-1] == suffix:
                file_path = file_path.replace('\\', '/')
                list_name.append(file_path)


def get_bboxes(xml_path):
    tree = ET.parse(open(xml_path, 'rb'))
    root = tree.getroot()
    bboxes, cls = [], []
    for obj in root.iter('object'):
        obj_cls = obj.find('name').text
        xmlbox = obj.find('bndbox')
        xmin = float(xmlbox.find('xmin').text)
        xmax = float(xmlbox.find('xmax').text)
        ymin = float(xmlbox.find('ymin').text)
        ymax = float(xmlbox.find('ymax').text)
        bboxes.append([xmin, ymin, xmax, ymax])
        cls.append(obj_cls)
    bboxes = np.asarray(bboxes, np.int32)  # np.int was removed in recent NumPy versions
    return bboxes, cls


def crop_img(img_path, xml_path):
    img = cv2.imread(img_path)
    h_img, w_img, _ = img.shape
    bboxes, cls = get_bboxes(xml_path)

    # If the image contains no bboxes, crop the longer side down to the shorter side's length
    if len(cls) == 0:
        diff_half = abs(h_img - w_img) // 2
        if diff_half > 0:  # a zero margin would make img[0:-0] an empty slice
            if h_img >= w_img:
                img = img[diff_half:-diff_half, :, :]
            else:
                img = img[:, diff_half:-diff_half, :]
        return img, bboxes, cls

    # Get the union bbox that encloses all object bboxes
    max_bbox = np.concatenate([np.min(bboxes[:, 0:2], axis=0),
                               np.max(bboxes[:, 2:4], axis=0)], axis=-1)

    # Randomly crop each side while keeping max_bbox inside the cropped image
    # top crop:
    top_crop_total = max(max_bbox[1] - 10, 0)
    top = random.randint(0, top_crop_total)
    # left crop:
    left_crop_total = max(max_bbox[0] - 10, 0)
    left = random.randint(0, left_crop_total)
    # down crop:
    down_crop_total = max(h_img-max_bbox[3] - 10, 0)
    down = random.randint(0, down_crop_total)
    # right crop:
    right_crop_total = max(w_img-max_bbox[2] - 10, 0)
    right = random.randint(0, right_crop_total)

    # Extent of the cropped sub-image within the original image
    x1, y1, x2, y2 = left, top, w_img-right, h_img-down

    # crop img and bboxes
    img = img[y1:y2, x1:x2, :]
    bboxes[:, [0, 2]] = bboxes[:, [0, 2]] - x1
    bboxes[:, [1, 3]] = bboxes[:, [1, 3]] - y1

    return img, bboxes, cls


def joint_imgs(img1, bboxes1, cls1, img2, bboxes2, cls2):
    # Size of the joined canvas: widths are summed, height is the larger of the two
    W = img1.shape[1] + img2.shape[1]
    H = max(img1.shape[0], img2.shape[0])

    # zero img
    img = np.zeros((H, W, 3), dtype=np.uint8)

    # place img1 on the left and img2 on the right of the canvas
    img[0:img1.shape[0], 0:img1.shape[1]] = img1
    img[0:img2.shape[0], img1.shape[1]:] = img2

    cls = cls1 + cls2
    if len(cls2) != 0:
        # shift img2's boxes right by img1's width, since img2 sits to the right of img1
        bboxes2[:, [0, 2]] = bboxes2[:, [0, 2]] + img1.shape[1]
        if len(cls1) == 0:
            return img, bboxes2, cls
        else:
            bboxes = np.vstack((bboxes1, bboxes2))
            return img, bboxes, cls
    else:
        if len(cls1) == 0:
            return img, bboxes2, cls
        else:
            return img, bboxes1, cls


def vis_bboxes(filename, img, bboxes, cls, if_print=False):

    if if_print:
        print("filename:{}\nbboxes:{}\nclasses:{}".format(filename, bboxes, cls))

    for j, p in enumerate(bboxes):
        x1, y1, x2, y2 = [int(v) for v in p[:4]]  # OpenCV expects plain int coordinates
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255))
        cv2.putText(img, cls[j], (x1, y1 + 10), cv2.FONT_HERSHEY_PLAIN, 1,
                    (0, 0, 255), 1)
    os.makedirs('crop_vis', exist_ok=True)  # otherwise cv2.imwrite fails silently
    cv2.imwrite('crop_vis/' + filename.split('/')[-1], img)


def test_crop(args):
    xmls = []
    list_dir(args.raw_data_path, xmls, suffix='xml')
    imgs = [xml.replace('.xml', '.jpg') for xml in xmls]
    xml1 = random.choice(xmls)
    xml2 = random.choice(xmls)
    img1 = xml1.replace('.xml', '.jpg')
    img2 = xml2.replace('.xml', '.jpg')
    img1_crop, bboxes1_crop, cls1 = crop_img(img1, xml1)
    img2_crop, bboxes2_crop, cls2 = crop_img(img2, xml2)
    img, bboxes, cls = joint_imgs(img1_crop, bboxes1_crop, cls1, img2_crop, bboxes2_crop, cls2)

    # Visualize and save
    # vis_bboxes(img1, img1_crop, bboxes1_crop, cls1)
    # vis_bboxes(img2, img2_crop, bboxes2_crop, cls2)
    vis_bboxes('joint.jpg', img, bboxes, cls)

def main(args):

    os.makedirs(args.new_data_path, exist_ok=True)
    xmls = []
    list_dir(args.raw_data_path, xmls, suffix='xml')
    # imgs = [xml.replace('.xml', '.jpg') for xml in xmls]

    for i, xml in enumerate(xmls):
        xml1 = xml
        xml2 = random.choice(xmls)

        img1 = xml1.replace('.xml', '.jpg')
        img2 = xml2.replace('.xml', '.jpg')

        img1_crop, bboxes1_crop, cls1 = crop_img(img1, xml1)
        img2_crop, bboxes2_crop, cls2 = crop_img(img2, xml2)
        img, bboxes, cls = joint_imgs(img1_crop, bboxes1_crop, cls1, img2_crop, bboxes2_crop, cls2)

        # save crop-joint img-xml pairs to new dataset path
        joint_name = 'joint_'+img1.split('/')[-1].replace('.jpg', '')
        img_save_path = os.path.join(args.new_data_path, joint_name+'.jpg')
        cv2.imwrite(img_save_path, img)
        gts = [[c]+b.tolist() for c, b in zip(cls, bboxes)]
        bboxes2xml(folder=args.new_data_path.split('/')[-1], img_name=joint_name,
                   width=img.shape[1], height=img.shape[0],
                   gts=gts, xml_save_to=args.new_data_path)

        # visualize crop_joint imgs
        vis_bboxes(joint_name+'.jpg', img, bboxes, cls, if_print=True)



def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--raw_data_path", default="path to raw_data_path", type=str,
                        help="raw dataset files")
    parser.add_argument("--new_data_path", default="path to new_data_path", type=str,
                        help="generated new dataset files")
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    main(args)
    # test_crop(args)
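The script imports bboxes2xml from a companion module that is not included in this post. Below is a minimal sketch of what such a helper could look like, assuming Pascal VOC style output and a signature matching the call in main(); it is an illustrative stand-in, not the author's original module.

# Hypothetical stand-in for the bboxes2xml helper imported above (not the original module)
import os
import xml.etree.ElementTree as ET


def bboxes2xml(folder, img_name, width, height, gts, xml_save_to):
    """Write a Pascal-VOC-style XML file; gts is a list of [class_name, xmin, ymin, xmax, ymax]."""
    root = ET.Element('annotation')
    ET.SubElement(root, 'folder').text = folder
    ET.SubElement(root, 'filename').text = img_name + '.jpg'
    size = ET.SubElement(root, 'size')
    ET.SubElement(size, 'width').text = str(width)
    ET.SubElement(size, 'height').text = str(height)
    ET.SubElement(size, 'depth').text = '3'
    for gt in gts:
        obj = ET.SubElement(root, 'object')
        ET.SubElement(obj, 'name').text = str(gt[0])
        bndbox = ET.SubElement(obj, 'bndbox')
        for tag, val in zip(('xmin', 'ymin', 'xmax', 'ymax'), gt[1:5]):
            ET.SubElement(bndbox, tag).text = str(int(val))
    ET.ElementTree(root).write(os.path.join(xml_save_to, img_name + '.xml'))

Assuming the script itself is saved as crop_and_joint.py (the file name and the paths below are placeholders), it can be run with:

python crop_and_joint.py --raw_data_path ./raw_data --new_data_path ./new_data

Each original image then yields one joined image and its XML annotation under new_data_path, plus a box visualization under the crop_vis folder.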

Reposted from blog.csdn.net/oYeZhou/article/details/111252723