VOC2012数据清洗

XML文件数据清洗(data clean)

xml文件读取中出现了3个问题:

  1. xml中某些坐标值不是整数

  2. 某些xml不仅包含目标边框的坐标,还包括了目标part的坐标(满足其他应用需求)

  3. xml中box坐标的存储不一定按(xmin, ymin, xmax, ymax)顺序存放,可能完全打乱顺序。

针对以上问题,为了制备检测所需的数据集,需要对以上情形进行处理。

坐标值不是整数

整个数据集中,只有2011_003353和2011_006777两个文件的坐标值为小数,将这两个文件删除即可。

检测时使用下面的read_xml_gtbox_and_label()依次读取xml文件:若某个文件在用int()转换坐标时抛出ValueError,则说明该文件存在此问题。

def read_xml_gtbox_and_label(xml_path):
    """Read a VOC xml file, taking bndbox coordinates by position (checking variant).

    This variant is meant for scanning the dataset for bad annotations:

    * ``int()`` raises ``ValueError`` when a coordinate is not an integer;
    * a box with ``x1 >= x2`` or ``y1 >= y2`` is printed and its xml path is
      appended to ``disabled_data.txt``. Because coordinates are taken in
      document order (assumed ``xmin, ymin, xmax, ymax``), files whose tags
      are stored in another order may also be reported here.

    :param xml_path: the path of voc xml
    :return: ``(img_height, img_width, gtbox_label)``. ``img_height`` and
        ``img_width`` are ``None`` when the file has no ``size`` element;
        ``gtbox_label`` is an int32 array of shape ``[num_of_gtboxes, 5]``
        whose rows are ``[ymin, xmin, ymax, xmax, label]`` (shape ``[0, 5]``
        when the file has no objects).
    :raises ValueError: if a coordinate is not an integer, or a ``bndbox``
        appears before its object's ``name``.
    :raises KeyError: if an object name is not in ``NAME_LABEL_MAP``.
    """
    root = ET.parse(xml_path).getroot()
    img_width = None
    img_height = None
    box_list = []
    for child_of_root in root:
        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                if child_item.tag == 'height':
                    img_height = int(child_item.text)

        if child_of_root.tag == 'object':
            label = None
            for child_item in child_of_root:
                if child_item.tag == 'name':
                    label = NAME_LABEL_MAP[child_item.text]
                if child_item.tag == 'bndbox':
                    # Positional read: assumes children are [xmin, ymin, xmax, ymax].
                    # int() deliberately fails on fractional coordinates.
                    tmp_box = [int(node.text) for node in child_item]
                    x1, y1, x2, y2 = tmp_box[0], tmp_box[1], tmp_box[2], tmp_box[3]
                    if x1 >= x2 or y1 >= y2:
                        # Degenerate or mis-ordered box: record the file for review.
                        print(xml_path)
                        with open("disabled_data.txt", 'a') as f:
                            f.write(xml_path + '\n')
                    if label is None:
                        raise ValueError('label is none, error')
                    tmp_box.append(label)  # [x1, y1, x2, y2, label]
                    box_list.append(tmp_box)

    if box_list:
        gtbox_label = np.array(box_list, dtype=np.int32)  # [x1, y1, x2, y2, label]
    else:
        # Keep a 2-D [0, 5] array so column indexing below still works.
        gtbox_label = np.zeros((0, 5), dtype=np.int32)

    # Reorder columns to [ymin, xmin, ymax, xmax, label].
    gtbox_label = gtbox_label[:, [1, 0, 3, 2, 4]]

    return img_height, img_width, gtbox_label

数据中包括非目标的box

该问题的影响在于不能使用bndbox = objects.getElementsByTagName('bndbox')来获取坐标,原因在于该语句会将object之下的part部分的box也读取进来,而非只读取object下的box。

错误的读取函数代码见demo_xml_read_xy_tag_wrong.py。

为了不读取part部分的数据,我们使用if child_of_root.tag == 'object':以及for child_item in child_of_root:实现仅对object tag下的box进行读取。

数据不严格按照(xmin, ymin, xmax, ymax)进行存放

使用if child_of_root.tag == 'object':以及for child_item in child_of_root:实现仅对object tag下的box进行读取时,同时依据box下的node的node.tag来对xmin, ymin, xmax, ymax分别进行赋值。

实现代码为demo_work_xml_read.py

import copy
import glob
import os

try:
    import xml.etree.cElementTree as ET
except ImportError:  # cElementTree was removed in Python 3.9
    import xml.etree.ElementTree as ET

import cv2
import matplotlib.pyplot as plt
import numpy as np

# PASCAL VOC class names; a name's position in this tuple is its integer
# label, with index 0 reserved for the background class.
_VOC_CLASS_NAMES = (
    'back_ground',
    'aeroplane',
    'bicycle',
    'bird',
    'boat',
    'bottle',
    'bus',
    'car',
    'cat',
    'chair',
    'cow',
    'diningtable',
    'dog',
    'horse',
    'motorbike',
    'person',
    'pottedplant',
    'sheep',
    'sofa',
    'train',
    'tvmonitor',
)

# Maps a class name found in <object><name> to its integer label.
NAME_LABEL_MAP = {name: idx for idx, name in enumerate(_VOC_CLASS_NAMES)}

# parameters
xml_path = "./xml"      # directory scanned for *.xml annotation files
img_format = ".jpg"     # extension appended to the xml stem to find the image
image_path = "./image"  # directory holding the images


def read_xml_gtbox_and_label(xml_path):
    """Parse a VOC xml file into the image size and its ground-truth boxes.

    Only a ``bndbox`` that is a direct child of an ``object`` is read, so the
    ``part`` boxes nested inside an object are ignored. Coordinates are
    assigned by tag name instead of by position because some files store
    ``xmin``/``ymin``/``xmax``/``ymax`` in a different order. The ``name`` and
    ``bndbox`` children of an object may appear in either order.

    Boxes with ``xmin >= xmax`` or ``ymin >= ymax`` are still returned, but the
    xml path is printed and appended to ``disabled_data.txt``.

    :param xml_path: the path of voc xml
    :return: ``(img_height, img_width, gtbox_label)``. ``img_height`` and
        ``img_width`` are ``None`` if the file has no ``size`` element;
        ``gtbox_label`` is an int32 array of shape ``[num_of_gtboxes, 5]``
        whose rows are ``[ymin, xmin, ymax, xmax, label]`` (shape ``[0, 5]``
        when the file has no objects).
    :raises ValueError: if a ``bndbox`` has an unknown or missing coordinate
        tag, a coordinate is not an integer, or an object with a box has no
        ``name``.
    :raises KeyError: if an object name is not in ``NAME_LABEL_MAP``.
    """
    coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    root = ET.parse(xml_path).getroot()
    img_width = None
    img_height = None
    box_list = []
    for child_of_root in root:
        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                if child_item.tag == 'height':
                    img_height = int(child_item.text)

        if child_of_root.tag == 'object':
            # Collect name and boxes first so their order inside <object>
            # does not matter.
            label = None
            bndboxes = []
            for child_item in child_of_root:
                if child_item.tag == 'name':
                    label = NAME_LABEL_MAP[child_item.text]
                elif child_item.tag == 'bndbox':
                    bndboxes.append(child_item)

            for bndbox in bndboxes:
                # A fresh dict per box, so coordinates from a previous object
                # can never leak into this one.
                coords = {}
                for node in bndbox:
                    if node.tag not in coord_tags:
                        raise ValueError("invalid tag name in bndbox")
                    coords[node.tag] = int(node.text)
                missing = [tag for tag in coord_tags if tag not in coords]
                if missing:
                    raise ValueError(
                        "missing tag(s) {} in bndbox of {}".format(missing, xml_path))
                x1, y1, x2, y2 = (coords[tag] for tag in coord_tags)
                if x1 >= x2 or y1 >= y2:
                    # Degenerate box: record the file for later review.
                    print(xml_path)
                    with open("disabled_data.txt", 'a') as f:
                        f.write(xml_path + '\n')
                if label is None:
                    raise ValueError('label is none, error')
                box_list.append([x1, y1, x2, y2, label])

    # reshape keeps a 2-D [N, 5] layout even when no object was found.
    gtbox_label = np.array(box_list, dtype=np.int32).reshape(-1, 5)  # [x1, y1, x2, y2, label]
    # Reorder columns to [ymin, xmin, ymax, xmax, label].
    gtbox_label = gtbox_label[:, [1, 0, 3, 2, 4]]

    return img_height, img_width, gtbox_label


if __name__ == '__main__':
    # Visual sanity check: draw every box parsed from each xml onto its image.
    for xml_file in glob.glob(os.path.join(xml_path, '*.xml')):
        # Derive the image name from the xml stem; os.path handles the
        # platform's separators and keeps any dots inside the stem.
        stem = os.path.splitext(os.path.basename(xml_file))[0]
        img_path = os.path.join(image_path, stem + img_format)

        if not os.path.exists(img_path):
            print('{} is not exist!'.format(img_path))
            continue

        img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml_file)

        img = cv2.imread(img_path)  # BGR channel order, or None if unreadable
        if img is None:
            print('{} cannot be read!'.format(img_path))
            continue
        img_show = img.copy()
        # Rows are [ymin, xmin, ymax, xmax, label]; cv2 expects (x, y) points
        # given as plain Python ints.
        for ymin, xmin, ymax, xmax, _label in gtbox_label:
            cv2.rectangle(img_show, (int(xmin), int(ymin)), (int(xmax), int(ymax)),
                          (0, 255, 0), 2)

        plt.figure()
        # matplotlib expects RGB; without conversion red and blue are swapped.
        plt.imshow(cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB))
        plt.show()

注意,这里的“清洗”并不修改原始数据,而是修改了数据读取的代码。由于这些问题是在网络训练时由数据引起的,所以将其归入数据清洗部分。

猜你喜欢

转载自blog.csdn.net/u010103202/article/details/81210176