Small Object Detection Study Notes

1. Data Augmentation based on the copy-paste strategy

This section summarizes what I learned from the GitHub repository Data_Augmentation_Zoo_for_Object_Detection. The related paper is Augmentation for Small Object Detection.

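The idea behind the code is simple:

  • First read each object's annotation (its bounding box) and use the box size to decide whether it is a small object;
  • If it is small, pick a random region of the same size elsewhere in the image (a region that does not overlap any object) and copy the small object's pixels into it.
  • The implementation adds a few extra sources of randomness (whether to augment at all, how many objects to copy), which I won't go into in detail here.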
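The three steps above can be illustrated with a stripped-down NumPy sketch. It is a simplified illustration rather than the repository's code: `boxes_overlap` and `copy_paste_once` are hypothetical names, boxes are assumed to be integer `(xmin, ymin, xmax, ymax)` tuples in pixels, only one copy of one object is made, and the new top-left corner is sampled directly instead of the center.

```python
import numpy as np

def boxes_overlap(a, b):
    """True if boxes a and b (xmin, ymin, xmax, ymax) share a positive-area intersection."""
    iw = min(a[2], b[2]) - max(a[0], b[0])
    ih = min(a[3], b[3]) - max(a[1], b[1])
    return iw > 0 and ih > 0

def copy_paste_once(img, box, boxes, rng, max_tries=30):
    """Hypothetical helper: copy the pixels of `box` to a random free spot.

    Returns the new box, or None if no non-overlapping spot was found
    within max_tries random attempts.
    """
    h, w = img.shape[:2]
    bw, bh = box[2] - box[0], box[3] - box[1]
    for _ in range(max_tries):
        # Sample a top-left corner so the new box stays inside the image
        x0 = int(rng.integers(0, w - bw + 1))
        y0 = int(rng.integers(0, h - bh + 1))
        new_box = (x0, y0, x0 + bw, y0 + bh)
        # Reject positions that touch any existing object with positive area
        if any(boxes_overlap(new_box, b) for b in boxes):
            continue
        # Copy the object's pixels into the free region
        img[y0:y0 + bh, x0:x0 + bw] = img[box[1]:box[3], box[0]:box[2]]
        return new_box
    return None
```

Passing the source box in `boxes` is what keeps the copy from landing on top of the original object.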

I tidied up the code for this part from GitHub Data_Augmentation_Zoo_for_Object_Detection and verified it on a real example.

1.1 Code

1.1.1 Import libraries and set variables and parameters

import cv2
import numpy as np
import random

import matplotlib.pyplot as plt
colors = [(39, 129, 113), (164, 80, 133), (83, 122, 114), ]  # box colors (0-255), indexed by class label

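"""   SMALL OBJECT AUGMENTATION   """
SMALL_OBJECT_AUGMENTATION = True
SOA_THRESH = 2000000  # area threshold (pixels^2) for "small"; the usual default is 64*64, tune it to your dataset
SOA_PROB = 1          # probability of applying the augmentation to a sample
SOA_COPY_TIMES = 3    # number of copies pasted for each selected object
SOA_EPOCHS = 30       # max random placement attempts per copy
SOA_ONE_OBJECT = False   # policy: copy a single small object
SOA_ALL_OBJECTS = False  # policy: copy every small object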
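Why is SOA_THRESH set so far above 64*64 here? A quick check on the three bird boxes used in section 1.1.4 shows that each has an area between roughly 9,000 and 23,000 pixels², so with a 64*64 = 4096 threshold none of them would be treated as small. The helper name `small_object_mask` is hypothetical; it applies the same `area <= thresh` rule as `issmallobject`.

```python
import numpy as np

def small_object_mask(boxes, thresh):
    """Boolean mask: True where box area (w * h, in pixels) <= thresh, as in issmallobject."""
    boxes = np.asarray(boxes)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return areas <= thresh

# The three bird boxes from section 1.1.4, [xmin, ymin, xmax, ymax, label]
boxes = np.array([[196, 175, 387, 294, 0],
                  [124, 298, 222, 399, 0],
                  [334, 298, 434, 390, 0]])

print(small_object_mask(boxes, 64 * 64).tolist())   # → [False, False, False]
print(small_object_mask(boxes, 2000000).tolist())   # → [True, True, True]
```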

1.1.2 Displaying the image and bboxes

def bbox_to_rect(bbox, color):
    return plt.Rectangle(xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],
                         fill=False, edgecolor=color, linewidth=2)

def easy_visualization(sample):
    image, annots = sample['img'], sample['annot']
    # cv2.imread returns BGR, while matplotlib expects RGB
    fig = plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    for annot in annots:
        annot = [int(x) for x in annot]
        label = annot[4]
        color = [c / 255.0 for c in colors[label]]  # matplotlib wants colors in [0, 1]
        fig.axes.add_patch(bbox_to_rect(annot, color))
    plt.show()

1.1.3 Core implementation of SmallObject_Augmentation

class SmallObjectAugmentation(object):
    def __init__(self, thresh=64*64, prob=0.5, copy_times=3, epochs=30, all_objects=False, one_object=False):
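        """
        sample = {'img': img, 'annot': annots}
        img: array of shape [height, width, 3]
        annot: one row per object, [xmin, ymin, xmax, ymax, label] in pixel coordinates
        thresh: area threshold for small objects; an object is small if annot_h * annot_w <= thresh
        prob: probability of applying small object augmentation to a sample
        copy_times: number of copies pasted for each selected object
        epochs: max number of random placement attempts for each copy
        all_objects: copy every small object (once)
        one_object: copy exactly one small object (once)
        """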
        self.thresh = thresh
        self.prob = prob
        self.copy_times = copy_times
        self.epochs = epochs
        self.all_objects = all_objects
        self.one_object = one_object
        if self.all_objects or self.one_object:
            self.copy_times = 1

    def issmallobject(self, h, w):
        return h * w <= self.thresh

    def compute_overlap(self, annot_a, annot_b):
        if annot_a is None:
            return False
        left_max = max(annot_a[0], annot_b[0])
        top_max = max(annot_a[1], annot_b[1])
        right_min = min(annot_a[2], annot_b[2])
        bottom_min = min(annot_a[3], annot_b[3])
        inter = max(0, (right_min-left_max)) * max(0, (bottom_min-top_max))
        if inter != 0:
            return True
        else:
            return False

    def donot_overlap(self, new_annot, annots):
        for annot in annots:
            if self.compute_overlap(new_annot, annot):
                return False
        return True

    def create_copy_annot(self, h, w, annot, annots):
        # Try up to self.epochs random positions for a copy of `annot` that lies
        # inside the image and does not overlap any existing box
        annot = annot.astype(int)  # np.int was removed in NumPy 1.24
        annot_h, annot_w = annot[3] - annot[1], annot[2] - annot[0]
        for _ in range(self.epochs):
            # Sample the center of the new box
            random_x, random_y = np.random.randint(int(annot_w / 2), int(w - annot_w / 2)), \
                np.random.randint(int(annot_h / 2), int(h - annot_h / 2))
            xmin, ymin = random_x - annot_w / 2, random_y - annot_h / 2
            xmax, ymax = xmin + annot_w, ymin + annot_h
            if xmin < 0 or xmax > w or ymin < 0 or ymax > h:
                continue
            new_annot = np.array([xmin, ymin, xmax, ymax, annot[4]], dtype=int)

            if not self.donot_overlap(new_annot, annots):
                continue

            return new_annot
        return None

    def add_patch_in_img(self, annot, copy_annot, image):
        # Paste the pixels of copy_annot (source) into the region of annot (destination)
        copy_annot = copy_annot.astype(int)
        image[annot[1]:annot[3], annot[0]:annot[2], :] = image[copy_annot[1]:copy_annot[3], copy_annot[0]:copy_annot[2], :]
        return image

    def __call__(self, sample):
        if self.all_objects and self.one_object:
            return sample
        if np.random.rand() > self.prob:
            return sample

        img, annots = sample['img'], sample['annot']
        h, w = img.shape[0], img.shape[1]

        small_object_list = list()
        for idx in range(annots.shape[0]):
            annot = annots[idx]
            annot_h, annot_w = annot[3] - annot[1], annot[2] - annot[0]
            if self.issmallobject(annot_h, annot_w):
                small_object_list.append(idx)

        num = len(small_object_list)
        # No Small Object
        if num == 0:
            return sample

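        # Refine the copy_object by the given policy
        # Policy 2: copy a random number (1..num) of small objects
        copy_object_num = np.random.randint(1, num + 1)
        # Policy 3: copy all small objects
        if self.all_objects:
            copy_object_num = num
        # Policy 1: copy only one small object
        if self.one_object:
            copy_object_num = 1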

        random_list = random.sample(range(num), copy_object_num)
        annot_idx_of_small_object = [
            small_object_list[idx] for idx in random_list]
        select_annots = annots[annot_idx_of_small_object, :]
        annots = annots.tolist()
        for idx in range(copy_object_num):
            annot = select_annots[idx]
            for i in range(self.copy_times):
                new_annot = self.create_copy_annot(h, w, annot, annots)
                if new_annot is not None:
                    # destination: new_annot, source: annot (img is modified in place)
                    img = self.add_patch_in_img(new_annot, annot, img)
                    annots.append(new_annot.tolist())

        return {'img': img, 'annot': np.array(annots)}
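`compute_overlap` and `donot_overlap` test each existing box in a Python loop. With many boxes per image the same check can be vectorized with NumPy. The sketch below is an alternative written for illustration, not part of the repository (the name `overlaps_any` is hypothetical), and it returns the opposite sense of `donot_overlap`: True when the candidate box intersects any existing box with positive area. Boxes that only touch along an edge do not count as overlapping, matching `compute_overlap`.

```python
import numpy as np

def overlaps_any(new_box, boxes):
    """True if new_box (xmin, ymin, xmax, ymax) has a positive-area
    intersection with at least one row of boxes (shape (N, >=4))."""
    boxes = np.asarray(boxes, dtype=float)
    if boxes.size == 0:
        return False
    # Intersection width/height against every box at once
    iw = np.minimum(boxes[:, 2], new_box[2]) - np.maximum(boxes[:, 0], new_box[0])
    ih = np.minimum(boxes[:, 3], new_box[3]) - np.maximum(boxes[:, 1], new_box[1])
    return bool(np.any((iw > 0) & (ih > 0)))
```

With this helper, `donot_overlap(new_annot, annots)` would be equivalent to `not overlaps_any(new_annot, annots)`.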

1.1.4 Running the program on a real image to check the result

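if __name__ == '__main__':

    img = cv2.imread(r"D://1.jpg", cv2.IMREAD_COLOR)
    if img is None:
        raise FileNotFoundError("could not read D://1.jpg")
    # Boxes in pixel coordinates: [xmin, ymin, xmax, ymax, label]
    Boxes = np.array([[196, 175, 387, 294, 0],
                      [124, 298, 222, 399, 0],
                      [334, 298, 434, 390, 0]], dtype=int)

    transform = SmallObjectAugmentation(
        thresh=SOA_THRESH, prob=SOA_PROB, copy_times=SOA_COPY_TIMES, epochs=SOA_EPOCHS,
        all_objects=SOA_ALL_OBJECTS, one_object=SOA_ONE_OBJECT)
    sample = {'img': img, 'annot': Boxes}
    sample = transform(sample)
    easy_visualization(sample)
    cv2.imwrite("D:\\2.jpg", sample['img'])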

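Here Data_Augmentation is applied to the small birds in the image D://1.jpg. The result before and after processing:

[Figure: the image before and after small object augmentation]
Very nice!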

Note that the program contains several random calls, so the output is different on every run!
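To get a repeatable result (for example when comparing before/after images), both random number generators used by the class need to be seeded: `__call__` uses Python's `random.sample` as well as NumPy's global `np.random.rand`/`np.random.randint`. A minimal sketch; `seed_everything` and `draw` are hypothetical names, and `draw` merely stands in for the random calls inside the augmentation.

```python
import random
import numpy as np

def seed_everything(seed):
    """Seed Python's random module and NumPy's global RNG."""
    random.seed(seed)
    np.random.seed(seed)

def draw():
    # Stand-in for the random calls made inside SmallObjectAugmentation
    return random.sample(range(10), 3), np.random.randint(0, 100, size=3).tolist()

seed_everything(42)
first = draw()
seed_everything(42)
second = draw()
print(first == second)  # → True
```

Calling `seed_everything(...)` once before `transform(sample)` should then make the augmented image the same on every run.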

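Also note that the input Boxes here are absolute pixel coordinates, not the normalized (relative) coordinates divided by image width/height that detection pipelines often use, so some changes are needed before using this in an object detector. You can refer to the resource I uploaded, which can be called directly from a detection model.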
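If your pipeline stores boxes as normalized corner coordinates (`xmin/W`, `ymin/H`, `xmax/W`, `ymax/H`), one way to adapt is to convert them to pixels before calling the augmentation and back afterwards. A minimal sketch, assuming corner format with the label in the last column; the function names and the 640x480 image size used below are hypothetical, and YOLO-style `(cx, cy, w, h)` boxes would need an extra center-to-corner conversion first.

```python
import numpy as np

def normalized_to_pixels(boxes, img_h, img_w):
    """[xmin, ymin, xmax, ymax, label] with coords in [0, 1] -> integer pixel coords."""
    boxes = np.asarray(boxes, dtype=float).copy()
    boxes[:, [0, 2]] *= img_w   # x coordinates scale with width
    boxes[:, [1, 3]] *= img_h   # y coordinates scale with height
    return np.round(boxes).astype(int)

def pixels_to_normalized(boxes, img_h, img_w):
    """Inverse conversion: pixel corner coords -> coords in [0, 1] (label kept as float)."""
    boxes = np.asarray(boxes, dtype=float).copy()
    boxes[:, [0, 2]] /= img_w
    boxes[:, [1, 3]] /= img_h
    return boxes

# Round trip for one box on a hypothetical 640x480 image
px = np.array([[196, 175, 387, 294, 0]])
norm = pixels_to_normalized(px, 480, 640)
back = normalized_to_pixels(norm, 480, 640)
```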

References

[1] https://github.com/zzl-pointcloud/Data_Augmentation_Zoo_for_Object_Detection
[2] Paper: Augmentation for Small Object Detection (Kisantal et al., 2019)

Source: blog.csdn.net/WANGWUSHAN/article/details/118422175