版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xiaomifanhxx/article/details/83376363
本章节主要讲述了对Pascal VOC数据集的信息分析:将标注好的xml文件内容解析后,每张图片存储为一个annotation_data字典,并汇总到all_imgs列表中,以便于后面进行读取,进行检测与分类。代码解析也写到了代码里面,方便查看。
import os
import cv2
import xml.etree.ElementTree as ET##解析xml文件的编译器
import numpy as np
def _read_image_set(list_path, quiet=False):
    """Return the set of '<id>.jpg' names listed in an ImageSets/Main file.

    A file that cannot be read yields whatever was collected so far (usually
    an empty set); the error is printed unless ``quiet`` is true.
    """
    names = set()
    try:
        with open(list_path) as f:
            for line in f:
                names.add(line.strip() + '.jpg')
    except (OSError, ValueError) as e:
        if not quiet:
            print(e)
    return names


def _parse_annotation(annot, imgs_path, trainval_files, test_files):
    """Parse one VOC xml file into an annotation dict.

    Returns None when the file declares no <object>; raises if a required
    element is missing or malformed, so the caller can skip the file.
    """
    element = ET.parse(annot).getroot()
    element_objs = element.findall('object')
    element_filename = element.find('filename').text
    size = element.find('size')
    element_width = int(size.find('width').text)
    element_height = int(size.find('height').text)

    if not element_objs:
        return None

    # Anything not explicitly listed in test.txt is treated as trainval.
    if element_filename not in trainval_files and element_filename in test_files:
        imageset = 'test'
    else:
        imageset = 'trainval'

    annotation_data = {'filepath': os.path.join(imgs_path, element_filename),
                       'width': element_width,
                       'height': element_height,
                       'bboxes': [],
                       'imageset': imageset}

    for element_obj in element_objs:
        class_name = element_obj.find('name').text
        obj_bbox = element_obj.find('bndbox')
        x1, y1, x2, y2 = (int(round(float(obj_bbox.find(tag).text)))
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        # A missing <difficult> tag is treated as "not difficult". Compare
        # with None explicitly: a childless Element is falsy.
        difficult_el = element_obj.find('difficult')
        difficulty = difficult_el is not None and int(difficult_el.text) == 1
        annotation_data['bboxes'].append(
            {'class': class_name, 'x1': x1, 'x2': x2, 'y1': y1, 'y2': y2, 'difficult': difficulty})

    return annotation_data


def get_data(input_path, visualise=False):
    """Parse Pascal VOC annotations found under ``input_path``.

    Looks for the ``VOC2007`` and ``VOC2012`` sub-directories, each with
    ``Annotations``, ``JPEGImages`` and ``ImageSets/Main``. A sub-directory
    without an ``Annotations`` folder is skipped with a message. Annotation
    files that fail to parse are reported with ``print`` and skipped; they
    contribute nothing to the returned counts.

    Args:
        input_path: directory containing the VOC2007 / VOC2012 folders.
        visualise: when true, show every parsed image with its boxes drawn
            using OpenCV and wait for a key press.

    Returns:
        A tuple ``(all_imgs, classes_count, class_mapping)``:
        all_imgs -- list of dicts with keys 'filepath', 'width', 'height',
            'imageset' ('trainval' or 'test') and 'bboxes' (list of dicts
            with 'class', 'x1', 'x2', 'y1', 'y2', 'difficult').
        classes_count -- dict mapping class name to its number of boxes
            over all returned images.
        class_mapping -- dict mapping class name to an integer index,
            assigned in order of first appearance.
    """
    all_imgs = []
    classes_count = {}
    class_mapping = {}

    data_paths = [os.path.join(input_path, s) for s in ['VOC2007', 'VOC2012']]

    print('Parsing annotation files')

    for data_path in data_paths:
        annot_path = os.path.join(data_path, 'Annotations')   # .xml files
        imgs_path = os.path.join(data_path, 'JPEGImages')     # .jpg files
        sets_path = os.path.join(data_path, 'ImageSets', 'Main')

        if not os.path.isdir(annot_path):
            print('Skipping {}: no Annotations directory'.format(data_path))
            continue

        trainval_files = _read_image_set(os.path.join(sets_path, 'trainval.txt'))
        # Most Pascal VOC 2012 distributions ship without test.txt, so a
        # missing file there is expected and not reported.
        test_files = _read_image_set(os.path.join(sets_path, 'test.txt'),
                                     quiet=data_path.endswith('VOC2012'))

        # os.listdir order is arbitrary; sorting keeps the class indices
        # reproducible across runs and machines.
        annots = [os.path.join(annot_path, s) for s in sorted(os.listdir(annot_path))]

        for annot in annots:
            try:
                annotation_data = _parse_annotation(
                    annot, imgs_path, trainval_files, test_files)
            except Exception as e:
                # Best effort: report the broken file and keep going.
                print(e)
                continue

            if annotation_data is None:
                continue

            # Update the statistics only once the whole file parsed, so a
            # half-parsed image never leaks into the counts.
            for bbox in annotation_data['bboxes']:
                class_name = bbox['class']
                classes_count[class_name] = classes_count.get(class_name, 0) + 1
                if class_name not in class_mapping:
                    class_mapping[class_name] = len(class_mapping)

            all_imgs.append(annotation_data)

            if visualise:
                try:
                    img = cv2.imread(annotation_data['filepath'])
                    for bbox in annotation_data['bboxes']:
                        cv2.rectangle(img, (bbox['x1'], bbox['y1']),
                                      (bbox['x2'], bbox['y2']), (0, 0, 255))
                    cv2.imshow('img', img)
                    cv2.waitKey(0)
                except Exception as e:
                    print(e)

    return all_imgs, classes_count, class_mapping
该函数返回参数:
all_imgs:存储的是每一张图片的内容,包括filepath,width,height,imageset(训练集/测试集),bboxes(每个目标框的类别、坐标以及difficult);
classes_count:存储的是所有已解析的标注文件(包括训练集和测试集)中,每一类目标框的总数量;
class_mapping:存储的是一个字典,key:value代表着 类别名称:类别编号(按类别首次出现的顺序从0开始编号)。