实现Faster R-CNN的keras代码理解(一)-VOC数据解析

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xiaomifanhxx/article/details/83376363

    本章节主要讲述对Pascal VOC数据集标注信息的解析:将每个标注好的xml文件解析为一个annotation_data字典,并依次追加到all_imgs列表中,以便后面读取并用于检测与分类。代码解析也写在了代码注释里,方便查看。

import os
import cv2
import xml.etree.ElementTree as ET##解析xml文件的编译器
import numpy as np
def get_data(input_path):
    """Parse the Pascal VOC annotation files found under *input_path*.

    The function looks for the sub-directories ``VOC2007`` and ``VOC2012``
    inside *input_path*. For each of them it reads every file in
    ``Annotations`` and uses ``ImageSets/Main/trainval.txt`` and
    ``ImageSets/Main/test.txt`` to decide which image set an image belongs to.

    Args:
        input_path: Directory that contains the ``VOC2007`` / ``VOC2012``
            folders.

    Returns:
        A tuple ``(all_imgs, classes_count, class_mapping)``:

        * ``all_imgs`` -- list with one dict per image that has at least one
          annotated object. Keys: ``filepath``, ``width``, ``height``,
          ``imageset`` (``'trainval'`` or ``'test'``) and ``bboxes`` (a list of
          dicts with ``class``, ``x1``, ``x2``, ``y1``, ``y2``, ``difficult``).
        * ``classes_count`` -- dict mapping class name to the number of
          objects of that class seen while parsing.
        * ``class_mapping`` -- dict mapping class name to an integer index,
          assigned in order of first appearance.

    Raises:
        OSError: If an ``Annotations`` directory cannot be listed.

    Notes:
        Annotation files that cannot be parsed are reported with ``print`` and
        skipped. Annotation files without any ``<object>`` are skipped as
        well; previously such a file re-appended the record of the image
        parsed before it (or raised ``NameError`` when it came first).
    """
    all_imgs = []
    classes_count = {}
    class_mapping = {}

    # Debug switch: when True every parsed image is displayed with its boxes.
    visualise = False

    data_paths = [os.path.join(input_path, s) for s in ['VOC2007', 'VOC2012']]

    print('Parsing annotation files')

    for data_path in data_paths:
        annot_path = os.path.join(data_path, 'Annotations')    # .xml files
        imgs_path = os.path.join(data_path, 'JPEGImages')      # .jpg files
        imgsets_path_trainval = os.path.join(
            data_path, 'ImageSets', 'Main', 'trainval.txt')    # training names
        imgsets_path_test = os.path.join(
            data_path, 'ImageSets', 'Main', 'test.txt')        # test names

        # Sets keep the per-annotation membership tests below O(1).
        trainval_files = set()
        test_files = set()

        try:
            with open(imgsets_path_trainval) as f:
                for line in f:
                    trainval_files.add(line.strip() + '.jpg')
        except Exception as e:
            # Best effort: a missing split file only affects the labelling.
            print(e)

        try:
            with open(imgsets_path_test) as f:
                for line in f:
                    test_files.add(line.strip() + '.jpg')
        except Exception as e:
            # Most Pascal VOC 2012 distributions ship without test.txt, so
            # its absence is expected there and not worth reporting.
            if not data_path.endswith('VOC2012'):
                print(e)

        annots = [os.path.join(annot_path, s) for s in os.listdir(annot_path)]
        for annot in annots:
            try:
                element = ET.parse(annot).getroot()

                element_objs = element.findall('object')
                element_filename = element.find('filename').text
                size = element.find('size')
                element_width = int(size.find('width').text)
                element_height = int(size.find('height').text)

                if not element_objs:
                    # Nothing annotated in this image: there is no record to
                    # build, so do not append anything for it.
                    continue

                if element_filename in trainval_files:
                    imageset = 'trainval'
                elif element_filename in test_files:
                    imageset = 'test'
                else:
                    # Images listed in neither split default to training.
                    imageset = 'trainval'

                annotation_data = {
                    'filepath': os.path.join(imgs_path, element_filename),
                    'width': element_width,
                    'height': element_height,
                    'bboxes': [],
                    'imageset': imageset,
                }

                for element_obj in element_objs:
                    class_name = element_obj.find('name').text
                    classes_count[class_name] = classes_count.get(class_name, 0) + 1

                    if class_name not in class_mapping:
                        # Next free index, i.e. order of first appearance.
                        class_mapping[class_name] = len(class_mapping)

                    obj_bbox = element_obj.find('bndbox')
                    # Coordinates may be stored as floats; round to pixels.
                    x1 = int(round(float(obj_bbox.find('xmin').text)))
                    y1 = int(round(float(obj_bbox.find('ymin').text)))
                    x2 = int(round(float(obj_bbox.find('xmax').text)))
                    y2 = int(round(float(obj_bbox.find('ymax').text)))
                    difficulty = int(element_obj.find('difficult').text) == 1
                    annotation_data['bboxes'].append(
                        {'class': class_name, 'x1': x1, 'x2': x2,
                         'y1': y1, 'y2': y2, 'difficult': difficulty})

                all_imgs.append(annotation_data)

                if visualise:
                    img = cv2.imread(annotation_data['filepath'])
                    for bbox in annotation_data['bboxes']:
                        cv2.rectangle(img, (bbox['x1'], bbox['y1']),
                                      (bbox['x2'], bbox['y2']), (0, 0, 255))
                    cv2.imshow('img', img)
                    cv2.waitKey(0)

            except Exception as e:
                # Malformed annotation: report it and move on to the next file.
                print(e)
                continue

    return all_imgs, classes_count, class_mapping

该函数返回参数:

all_imgs:一个列表,每个元素是一张图片对应的字典,包含filepath(图片路径)、width、height、imageset(trainval/test)以及bboxes(每个目标的class、x1、y1、x2、y2和difficult);

classes_count:一个字典,存储的是所有已解析的标注文件中(不区分trainval和test),每一类目标的总数量;

class_mapping:一个字典,key为类别名称,value为该类别的整数编号(按类别首次出现的顺序从0开始依次递增)。

猜你喜欢

转载自blog.csdn.net/xiaomifanhxx/article/details/83376363