class VOCDetection(Dataset):
    """Detection dataset backed by a Pascal-VOC-style directory layout.

    Two layouts are supported, selected by ``custom``:

    * ``custom=True`` (the default): a user dataset stored under
      ``<data_dir>/<base_dir>/``, with its split file at
      ``<data_dir>/<base_dir>/ImageSets/Main/<split_name>.txt``.
      Only the FIRST ``(base_dir, split_name)`` pair of ``image_sets`` is
      read; any further pairs are ignored.
    * ``custom=False``: the stock VOC layout ``<data_dir>/VOC<year>/``;
      every ``(year, split_name)`` pair in ``image_sets`` is read and the
      resulting ids are concatenated in order.

    After construction ``self.ids`` holds ``(rootpath, image_id)`` tuples,
    one per non-blank line of the selected split file(s).

    Args:
        data_dir: Root directory that contains the dataset folder(s).
        image_sets: Sequence of ``(year_or_base_dir, split_name)`` pairs.
            ``None`` (the default) means
            ``[('2007', 'trainval'), ('2012', 'trainval')]``. Note that with
            the default ``custom=True`` this resolves to
            ``<data_dir>/2007/ImageSets/Main/trainval.txt``.
        img_size: Image size; forwarded to the base ``Dataset``
            constructor and also stored as ``self.img_size``.
        preproc: Stored as ``self.preproc`` (not used in ``__init__``).
        target_transform: Stored as ``self.target_transform``. The default
            instance is created once, when the class body is executed, and
            is shared by every dataset built without an explicit value.
        dataset_name: Stored as ``self.name``.
        custom: ``True`` for the custom layout, ``False`` for stock VOC.

    Raises:
        IndexError: ``custom`` is true and ``image_sets`` is empty.
        OSError: A split file cannot be opened.
    """

    def __init__(
        self,
        data_dir,
        image_sets=None,
        img_size=(416, 416),
        preproc=None,
        target_transform=AnnotationTransform(),
        dataset_name="VOC0712",
        custom=True,  # added to support custom (non-VOC-named) datasets
    ):
        super().__init__(img_size)
        if image_sets is None:
            # Build a fresh list per instance rather than sharing one
            # mutable default list across every dataset object.
            image_sets = [("2007", "trainval"), ("2012", "trainval")]
        self.root = data_dir
        self.image_set = image_sets
        self.img_size = img_size
        self.preproc = preproc
        self.target_transform = target_transform
        self.name = dataset_name
        # Path templates filled with (rootpath, image_id) pairs from self.ids.
        self._annopath = os.path.join("%s", "Annotations", "%s.xml")
        self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
        self._classes = VOC_CLASSES
        self.ids = []
        self.custom = custom
        if self.custom:
            # Custom dataset: <data_dir>/<base_dir>; only image_sets[0] is used.
            self.base_dir, self.custom_name = image_sets[0]
            rootpath = os.path.join(self.root, self.base_dir)
            self.ids.extend(self._read_image_set_ids(rootpath, self.custom_name))
        else:
            # Stock VOC dataset: <data_dir>/VOC<year> for each (year, split).
            for year, name in image_sets:
                # Ends up holding the year of the last pair processed.
                self._year = year
                rootpath = os.path.join(self.root, "VOC" + year)
                self.ids.extend(self._read_image_set_ids(rootpath, name))

    @staticmethod
    def _read_image_set_ids(rootpath, set_name):
        """Return ``(rootpath, image_id)`` pairs listed in one split file.

        Reads ``<rootpath>/ImageSets/Main/<set_name>.txt``, stripping each
        line. Blank lines are skipped: an empty id would otherwise expand
        the ``%s.xml`` / ``%s.jpg`` path templates to a nameless file.
        The file is closed before returning.
        """
        split_file = os.path.join(rootpath, "ImageSets", "Main", set_name + ".txt")
        with open(split_file) as f:
            stripped = (line.strip() for line in f)
            return [(rootpath, image_id) for image_id in stripped if image_id]
# Source: http://www.eepw.com.cn/zhuanlan/209850.html