Reading XML files with ET.parse (code based on the VOC2012 dataset)

Contents

The VOC2012 dataset

Image classes

Reading the XML file

Code


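The VOC2012 dataset

Link: https://pan.baidu.com/s/1uV5j6BEkwd8yKLUhaUPzPQ?pwd=aaaa
Extraction code: aaaa

Dataset directory (10 images in total):

Annotations holds the labels (XML files) for the 10 images, the txt file in ImageSets-main lists the names of the 10 images, and JPEGImages holds the 10 images themselves.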
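To make the layout concrete, the sketch below pairs each image name listed in the txt file with its annotation path and image path. The helper name list_samples and both of its arguments are assumptions for illustration; only the Annotations and JPEGImages folder names come from the description above.

```python
import os


def list_samples(root, split_file):
    """Return (xml_path, jpg_path) pairs for the image names listed in split_file.

    root is the dataset folder that contains Annotations/ and JPEGImages/;
    split_file is the txt file holding one image name per line.
    Both arguments are hypothetical inputs for this sketch.
    """
    with open(split_file, encoding="utf-8") as f:
        # Class-specific VOC split files carry a second column (a +1/-1 flag),
        # so keep only the first token of each non-empty line.
        names = [line.split()[0] for line in f if line.strip()]

    return [
        (
            os.path.join(root, "Annotations", name + ".xml"),
            os.path.join(root, "JPEGImages", name + ".jpg"),
        )
        for name in names
    ]
```

Each returned pair can then be passed to an XML parser and an image loader respectively.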

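Image classes

There are 21 classes (the 20 VOC object classes plus a background entry); their names are listed in CLASSES_NAME. zip pairs each class name with an index from 0 to 20, and dict turns those pairs into a dictionary.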

# Build the class dictionary: class_name -> index
CLASSES_NAME = (
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)
name2id = dict(zip(CLASSES_NAME, range(len(CLASSES_NAME))))
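As a quick check of how the zip-based dictionary behaves, here is a minimal sketch on a shortened class tuple. The reverse mapping id2name is an addition for illustration and is not part of the original code.

```python
# Shortened class tuple for illustration; the full 21-entry tuple works the same way.
CLASSES_NAME = ("__background__", "aeroplane", "bicycle", "bird")

# zip pairs each name with its position, so name2id maps name -> index.
name2id = dict(zip(CLASSES_NAME, range(len(CLASSES_NAME))))

# Hypothetical reverse mapping (index -> name), handy when printing results.
id2name = {idx: name for name, idx in name2id.items()}

print(name2id["bird"])  # 3
print(id2name[1])       # aeroplane
```

The same dictionary can also be built with enumerate: {name: i for i, name in enumerate(CLASSES_NAME)}.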

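Reading the XML file

An XML document is a tree, so it is read from the root node downward, one level at a time. ET is the usual alias for the xml.etree.ElementTree module: ET.parse reads the label file, and .getroot() returns the root node, called anno here.

Each node (element) exposes three things:

tag: the element's name, which identifies what kind of data the element holds

attrib: the element's attributes, stored as a dictionary

text: the element's text content as a string; use .find(tag).text to read the content of a child node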
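The three properties can be tried out on a small hand-written fragment. The XML below only imitates the style of a VOC annotation; its contents, including the verified attribute, are made up for this sketch. ET.fromstring is used instead of ET.parse so that no file is needed.

```python
import xml.etree.ElementTree as ET

# Hand-written fragment in the style of a VOC annotation (not taken from the dataset).
SAMPLE_XML = """<annotation verified="yes">
    <filename>example.jpg</filename>
    <object>
        <name>dog</name>
        <difficult>0</difficult>
    </object>
</annotation>
"""

root = ET.fromstring(SAMPLE_XML)  # fromstring returns the root element directly

print(root.tag)                                # annotation
print(root.attrib)                             # {'verified': 'yes'}
print(root.find("filename").text)              # example.jpg
print(root.find("object").find("name").text)   # dog (walking down one level at a time)
print([child.tag for child in root])           # ['filename', 'object']
```

Iterating over an element yields only its direct children, while .iter(tag) searches the whole subtree below it.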

Code

import xml.etree.ElementTree as ET
import os
import numpy as np

# Build the class dictionary: class_name -> index
CLASSES_NAME = (
    "__background__",
    "aeroplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "diningtable",
    "dog",
    "horse",
    "motorbike",
    "person",
    "pottedplant",
    "sheep",
    "sofa",
    "train",
    "tvmonitor",
)
name2id = dict(zip(CLASSES_NAME, range(len(CLASSES_NAME))))

def get_xml_label(label_path):
    """Read bounding boxes and class indices from a VOC XML label file."""

    anno = ET.parse(label_path).getroot()  # .getroot() returns the root node
    # for node in anno:  # direct children of the root
    #     print(node.tag, node.attrib)  # node name and attributes (includes the object nodes)

    boxes = []
    classes = []
    for obj in anno.iter("object"):  # iterate over every <object> node in the tree
        # for i in obj:
        #     print(i)  # children of object: name pose truncated occluded bndbox difficult

        # Skip objects flagged as difficult
        difficult = int(obj.find("difficult").text) == 1
        if difficult:
            continue

        # Look up the bounding box coordinates
        _box = obj.find("bndbox")
        box = [
            _box.find("xmin").text,
            _box.find("ymin").text,
            _box.find("xmax").text,
            _box.find("ymax").text,
        ]

        # VOC pixel coordinates are 1-based; subtract 1 to make them 0-based
        TO_REMOVE = 1
        box = tuple(float(v) - TO_REMOVE for v in box)
        boxes.append(box)

        # Class index for this box
        name = obj.find("name").text.lower().strip()  # class name, lowercased, with surrounding whitespace and newlines removed
        classes.append(name2id[name])  # index

    # reshape keeps the (N, 4) shape even when no object is kept
    boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
    return boxes, classes

label_path = os.path.join(r'D:\VOC2012\Annotations', '%s.xml')  # %s is a placeholder for the file name
boxes, classes = get_xml_label(label_path % '2008_000007')
print(boxes)
print(classes)
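The call above depends on a local copy of the dataset. The self-contained sketch below writes a small hand-made annotation to a temporary file and runs the same parsing steps on it, so the handling of the difficult flag and the 1-based coordinates can be checked without any download. The XML content, the helper name parse_boxes, and the two-entry class dictionary are assumptions for illustration. NumPy is left out here, so the boxes come back as a plain list of tuples.

```python
import os
import tempfile
import xml.etree.ElementTree as ET

# Hand-made annotation: one normal object and one flagged as difficult (illustrative values).
SAMPLE_XML = """<annotation>
    <object>
        <name>Dog </name>
        <difficult>0</difficult>
        <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>110</xmax><ymax>220</ymax></bndbox>
    </object>
    <object>
        <name>cat</name>
        <difficult>1</difficult>
        <bndbox><xmin>5</xmin><ymin>5</ymin><xmax>50</xmax><ymax>50</ymax></bndbox>
    </object>
</annotation>
"""

name2id = {"cat": 8, "dog": 12}  # subset of the full 21-class dictionary


def parse_boxes(label_path):
    """Compact version of the parsing steps: skip difficult objects, shift boxes to 0-based."""
    root = ET.parse(label_path).getroot()
    boxes, classes = [], []
    for obj in root.iter("object"):
        if int(obj.find("difficult").text) == 1:
            continue  # the difficult cat is dropped here
        bndbox = obj.find("bndbox")
        keys = ("xmin", "ymin", "xmax", "ymax")
        boxes.append(tuple(float(bndbox.find(k).text) - 1 for k in keys))
        classes.append(name2id[obj.find("name").text.lower().strip()])
    return boxes, classes


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "sample.xml")
    with open(path, "w", encoding="utf-8") as f:
        f.write(SAMPLE_XML)
    boxes, classes = parse_boxes(path)

print(boxes)    # [(9.0, 19.0, 109.0, 219.0)]
print(classes)  # [12]
```

Only the dog survives: the cat is skipped because of its difficult flag, and the name "Dog " is normalized by lower() and strip() before the dictionary lookup.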


Reposted from blog.csdn.net/m0_63077499/article/details/127901105