pytorch加载pascal&&coco数据集

上一篇博客https://blog.csdn.net/goodxin_ie/article/details/84315458我们详细介绍了pascal&&coco数据集,本篇我们将介绍pytorch如何加载

一、目标

pascal数据集的数据源是jpg图片,便签是xml文件,而pytorch运算使用的数据是Tensor。因此我们的目标是将jpg和xml文件转化为可供程序运算使用的Tensor或者numpy类型(Tesnor和numpy可以相互转化)。

回忆一下目标检测算法需要的标签信息,有类别和bbox框。在pascal数据集中,每张图片中的对象由xml中的objec标定,每个对象存在类别名name,位置框('ymin', 'xmin', 'ymax', 'xmax'),是否为困难样本的标记difficult。

二、解析xml文件

调用ElementTree元素树可以很方便的解析出xml文件的各种信息。我们主要使用其中的find方法查找对应属性的信息

ET.findall('object')   #查找对象
ET.findall('bndbox')   #查找位置框

完整的解析pasacal中xml文件代码如下:

输入参数:路径,文件名,是否使用困难样本

输出: bbox,label,difficult   (类型np.float32)

def parseXml(data_dir,id,use_difficult=False):
        anno = ET.parse(
            os.path.join(data_dir, 'Annotations', id + '.xml'))
        bbox = list()
        label = list()
        difficult = list()
        for obj in anno.findall('object'):
            if not use_difficult and int(obj.find('difficult').text) == 1:
                continue
            difficult.append(int(obj.find('difficult').text))
            bndbox_anno = obj.find('bndbox')

            bbox.append([
                int(bndbox_anno.find(tag).text) - 1
                for tag in ('ymin', 'xmin', 'ymax', 'xmax')])
            name = obj.find('name').text.lower().strip()
            label.append(VOC_BBOX_LABEL_NAMES.index(name))
        bbox = np.stack(bbox).astype(np.float32)     #from list to array
        label = np.stack(label).astype(np.int32)

        difficult = np.array(difficult, dtype=np.bool).astype(np.uint8)  # PyTorch don't support np.bool
        return  bbox, label, difficult

猜你喜欢

转载自blog.csdn.net/goodxin_ie/article/details/84317276