XML文件data clean
xml文件读取中出现了3个问题:
xml中某些坐标值不是整数
某些xml不仅包含目标边框的坐标,还包括了目标part的坐标(满足其他应用需求)
xml中box坐标的存储不一定按(xmin, ymin, xmax, ymax)顺序存放,可能完全打乱顺序。
针对以上问题,为了制备检测所需的数据集,需要对以上情形进行处理。
坐标值不是整数
整个数据集中,只有 `2011_003353` 和 `2011_006777` 两个文件的坐标值为小数,将其删除即可。
检测时使用 `read_xml_gtbox_and_label()` 依次读取 xml 文件:`int()` 无法转换 "12.5" 这类小数形式的字符串,因此若读取时抛出 `ValueError`,则说明该文件存在此问题。
def read_xml_gtbox_and_label(xml_path):
    """Read the image size and ground-truth boxes from one VOC xml file.

    Every ``bndbox`` that is a direct child of an ``object`` element is read
    positionally: its children are converted with ``int()`` in document order
    and taken as x1, y1, x2, y2. ``int()`` raises ``ValueError`` on a decimal
    string such as ``"12.5"``, which is how files with non-integer
    coordinates surface when this reader is run over a dataset.

    :param xml_path: the path of voc xml
    :return: a tuple ``(img_height, img_width, gtbox_label)``; the two sizes
        are read from the ``size`` element (``None`` if it is absent) and
        ``gtbox_label`` is an int32 array of shape [num_of_gtboxes, 5] with
        [y1, x1, y2, x2, label] in each row (shape [0, 5] when the file
        contains no object)
    """
    tree = ET.parse(xml_path)
    root = tree.getroot()
    img_width = None
    img_height = None
    box_list = []

    for child_of_root in root:
        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                if child_item.tag == 'height':
                    img_height = int(child_item.text)

        if child_of_root.tag == 'object':
            label = None
            for child_item in child_of_root:
                if child_item.tag == 'name':
                    label = NAME_LABEL_MAP[child_item.text]
                if child_item.tag == 'bndbox':
                    # NOTE: positional read; this assumes the children are
                    # stored in (xmin, ymin, xmax, ymax) order.
                    tmp_box = [int(node.text) for node in child_item]
                    x1 = tmp_box[0]
                    y1 = tmp_box[1]
                    x2 = tmp_box[2]
                    y2 = tmp_box[3]
                    if x1 >= x2 or y1 >= y2:
                        # Degenerate box: report the file and record it, but
                        # still keep the box in the result.
                        print(xml_path)
                        with open("disabled_data.txt", 'a') as f:
                            f.write('{}\n'.format(xml_path))
                    assert label is not None, 'label is none, error'
                    tmp_box.append(label)  # [x1, y1, x2, y2, label]
                    box_list.append(tmp_box)

    if box_list:
        gtbox_label = np.array(box_list, dtype=np.int32)
    else:
        # An empty list would give a 1-D array that cannot be column-indexed.
        gtbox_label = np.empty((0, 5), dtype=np.int32)

    # Reorder columns from [x1, y1, x2, y2, label] to [y1, x1, y2, x2, label].
    gtbox_label = gtbox_label[:, [1, 0, 3, 2, 4]]
    return img_height, img_width, gtbox_label
数据中包括非目标的box
该问题的影响在于不能使用 `bndbox = objects.getElementsByTagName('bndbox')` 来获取坐标,原因在于该语句会将 `object` 之下 `part` 部分的 box 也读取进来,而非只读取 `object` 下的 box。
错误读取函数代码为 `demo_xml_read_xy_tag_wrong.py`。
为了不读取 `part` 部分的数据,我们使用 `if child_of_root.tag == 'object':` 以及 `for child_item in child_of_root:`,实现仅对 object tag 下的 box 进行读取。
数据不严格按照(xmin, ymin, xmax, ymax)进行存放
使用 `if child_of_root.tag == 'object':` 以及 `for child_item in child_of_root:` 实现仅对 object tag 下的 box 进行读取时,同时依据 box 下各 node 的 `node.tag` 来对 `xmin, ymin, xmax, ymax` 分别进行赋值。
实现代码为 `demo_work_xml_read.py`:
import xml.etree.cElementTree as ET
import numpy as np
import cv2
import copy
import glob
import os
import matplotlib.pyplot as plt
# Class name -> integer label. Labels follow the listing order below, so
# 'back_ground' is 0 and 'tvmonitor' is 20.
NAME_LABEL_MAP = {
    class_name: class_id
    for class_id, class_name in enumerate((
        'back_ground',
        'aeroplane',
        'bicycle',
        'bird',
        'boat',
        'bottle',
        'bus',
        'car',
        'cat',
        'chair',
        'cow',
        'diningtable',
        'dog',
        'horse',
        'motorbike',
        'person',
        'pottedplant',
        'sheep',
        'sofa',
        'train',
        'tvmonitor',
    ))
}

# parameters
# Directory scanned for '*.xml' annotation files.
xml_path = "./xml"
# Directory holding the images that belong to those annotations.
image_path = "./image"
# Extension appended to an annotation's base name to get its image file name.
img_format = ".jpg"
def read_xml_gtbox_and_label(xml_path):
    """Read the image size and ground-truth boxes from one VOC xml file.

    Only ``bndbox`` elements that are direct children of an ``object``
    element are read, so boxes nested under ``part`` are ignored. Coordinates
    are matched by tag name, so the order of ``xmin``/``ymin``/``xmax``/
    ``ymax`` inside ``bndbox`` does not matter, and neither does the relative
    order of ``name`` and ``bndbox`` inside ``object``.

    A box with ``xmin >= xmax`` or ``ymin >= ymax`` is still returned, but the
    file path is printed and appended to ``disabled_data.txt``.

    :param xml_path: the path of voc xml
    :return: a tuple ``(img_height, img_width, gtbox_label)``; the two sizes
        are read from the ``size`` element (``None`` if it is absent) and
        ``gtbox_label`` is an int32 array of shape [num_of_gtboxes, 5] with
        [ymin, xmin, ymax, xmax, label] in each row (shape [0, 5] when the
        file contains no object)
    :raises ValueError: if a coordinate is not an integer literal, or if a
        ``bndbox`` holds an unexpected tag or lacks one of the four
        coordinates
    :raises KeyError: if an object name is not a key of NAME_LABEL_MAP
    :raises AssertionError: if an object has a box but no name
    """
    coord_tags = ('xmin', 'ymin', 'xmax', 'ymax')
    root = ET.parse(xml_path).getroot()
    img_width = None
    img_height = None
    box_list = []

    for child_of_root in root:
        if child_of_root.tag == 'size':
            for child_item in child_of_root:
                if child_item.tag == 'width':
                    img_width = int(child_item.text)
                elif child_item.tag == 'height':
                    img_height = int(child_item.text)

        elif child_of_root.tag == 'object':
            label = None
            object_boxes = []
            for child_item in child_of_root:
                if child_item.tag == 'name':
                    label = NAME_LABEL_MAP[child_item.text]
                elif child_item.tag == 'bndbox':
                    # Collect into a fresh dict per box so a missing tag can
                    # never silently reuse a value from an earlier box.
                    coords = {}
                    for node in child_item:
                        if node.tag not in coord_tags:
                            raise ValueError("invalid tag name in bndbox")
                        coords[node.tag] = int(node.text)
                    missing = [tag for tag in coord_tags if tag not in coords]
                    if missing:
                        raise ValueError(
                            "missing tag in bndbox: {}".format(', '.join(missing)))

                    x1, y1, x2, y2 = [coords[tag] for tag in coord_tags]
                    if x1 >= x2 or y1 >= y2:
                        # Degenerate box: report the file and record it.
                        print(xml_path)
                        with open("disabled_data.txt", 'a') as f:
                            f.write('{}\n'.format(xml_path))
                    object_boxes.append([x1, y1, x2, y2])

            # Checked after the whole object is scanned so that <name> may
            # appear after <bndbox>. Raised explicitly so it survives -O.
            if object_boxes and label is None:
                raise AssertionError('label is none, error')
            for box in object_boxes:
                box_list.append(box + [label])  # [x1, y1, x2, y2, label]

    if box_list:
        gtbox_label = np.array(box_list, dtype=np.int32)
    else:
        # An empty list would give a 1-D array that cannot be column-indexed.
        gtbox_label = np.empty((0, 5), dtype=np.int32)

    # Reorder columns from [xmin, ymin, xmax, ymax, label] to
    # [ymin, xmin, ymax, xmax, label].
    gtbox_label = gtbox_label[:, [1, 0, 3, 2, 4]]
    return img_height, img_width, gtbox_label
if __name__ == '__main__':
    # Draw every annotated box on its image and display it, one figure per
    # annotation file found under xml_path.
    for xml in glob.glob(xml_path + '/*.xml'):
        # to avoid path error in different development platform
        xml = xml.replace('\\', '/')
        # splitext keeps base names that themselves contain dots intact.
        img_name = os.path.splitext(os.path.basename(xml))[0] + img_format
        img_path = image_path + '/' + img_name
        if not os.path.exists(img_path):
            print('{} is not exist!'.format(img_path))
            continue

        img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml)

        # visualize boxes in img:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None.
            print('{} cannot be read!'.format(img_path))
            continue
        img_show = copy.deepcopy(img)

        # Rows are [ymin, xmin, ymax, xmax, label]; cv2 points are (x, y).
        for ymin, xmin, ymax, xmax, _label in gtbox_label:
            cv2.rectangle(img_show,
                          (int(xmin), int(ymin)),
                          (int(xmax), int(ymax)),
                          (0, 255, 0), 2)

        plt.figure()
        # cv2 loads images as BGR while matplotlib expects RGB; a colormap
        # argument has no effect on 3-channel data, so convert explicitly.
        plt.imshow(cv2.cvtColor(img_show, cv2.COLOR_BGR2RGB))
        plt.show()
注意,这里的 清洗 并不改变原始数据,而是改变了数据读取的代码。由于是网络训练时由数据引起的问题,所以放在了数据清洗部分。