Extraer clases específicas de un conjunto de datos de detección en formato VOC o COCO

Prefacio

A veces necesitamos extraer solo ciertas clases de un conjunto de datos etiquetado para el entrenamiento. Tomando como ejemplo los formatos de anotación habituales COCO y VOC, este artículo proporciona métodos para extraer clases específicas en ambos formatos. Hay muchos contenidos similares disponibles en línea; aquí se resumen y se registran para poder encontrarlos fácilmente cuando se necesiten más adelante.

1. Extraiga clases específicas del conjunto de datos en formato COCO

# COCO数据集提取某个类或者某些类

from pycocotools.coco import COCO
import os
import shutil
from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2
from PIL import Image, ImageDraw

# Paths to configure. Raw strings keep the Windows backslashes literal instead
# of relying on unrecognized escape sequences such as "\B" or "\c", which
# newer Python versions warn about.
savepath = r"D:\BaiduNetdiskDownload\coco\COCO/car/"
img_dir = savepath + 'images/'
anno_dir = savepath + 'annotations/'
datasets_list = ['train2017']

# COCO has 80 categories; list the names of the ones to extract here.
classes_names = ['car', 'bus', 'truck']
# Root of the original COCO dataset that contains every category.
# Expected directory layout:
#   $COCO_PATH
#   ----|annotations
#   ----|train2017
#   ----|val2017
#   ----|test2017
dataDir = r'D:\BaiduNetdiskDownload\coco\COCO/'

# Header of one VOC-style XML annotation file. The %-placeholders are filled,
# in order, with the image file name, width, height and depth.
headstr = """\
<annotation>
    <folder>VOC</folder>
    <filename>%s</filename>
    <source>
        <database>My Database</database>
        <annotation>COCO</annotation>
        <image>flickr</image>
        <flickrid>NULL</flickrid>
    </source>
    <owner>
        <flickrid>NULL</flickrid>
        <name>company</name>
    </owner>
    <size>
        <width>%d</width>
        <height>%d</height>
        <depth>%d</depth>
    </size>
    <segmented>0</segmented>
"""
# One <object> entry per bounding box. The %-placeholders are filled, in
# order, with the class name, xmin, ymin, xmax and ymax.
objstr = """\
    <object>
        <name>%s</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>%d</xmin>
            <ymin>%d</ymin>
            <xmax>%d</xmax>
            <ymax>%d</ymax>
        </bndbox>
    </object>
"""

# Closes the root <annotation> element opened by headstr.
tailstr = '''\
</annotation>
'''


def mkr(path):
    """Create directory ``path``, including missing parents, if it is absent.

    An existing directory is left untouched; nothing is ever deleted.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    # exist_ok=True replaces the separate os.path.exists() check and so
    # removes the window between checking and creating.
    os.makedirs(path, exist_ok=True)


def id2name(coco):
    """Return a dict mapping each category id to its category name.

    The categories are read from ``coco.dataset['categories']``; every entry
    must provide the keys ``'id'`` and ``'name'``.
    """
    return {
        category['id']: category['name']
        for category in coco.dataset['categories']
    }


def write_xml(anno_path, head, objs, tail):
    """Write one annotation file: ``head``, one entry per object, then ``tail``.

    Args:
        anno_path: Destination path; an existing file is overwritten.
        head: Already formatted header text.
        objs: Iterable of ``[name, xmin, ymin, xmax, ymax]`` records, each
            rendered through the module-level ``objstr`` template.
        tail: Closing text written after the last object.
    """
    # The context manager closes (and flushes) the file even when formatting
    # an object raises. The XML has no encoding declaration, so it is written
    # as UTF-8 rather than the platform default.
    with open(anno_path, "w", encoding="utf-8") as f:
        f.write(head)
        for obj in objs:
            f.write(objstr % (obj[0], obj[1], obj[2], obj[3], obj[4]))
        f.write(tail)


def save_annotations_and_imgs(coco, dataset, filename, objs):
    """Copy one source image and write its VOC-style XML annotation.

    The annotation gets the image's base name with an ``.xml`` extension,
    e.g. ``000000196610.jpg`` -> ``000000196610.xml``.

    Args:
        coco: Not used here; kept so existing callers keep working.
        dataset: Split name (e.g. ``'train2017'``); used as the sub-directory
            under ``dataDir``, ``img_dir`` and ``anno_dir``.
        filename: File name of the image inside the split directory.
        objs: Object records handed to ``write_xml``.
    """
    img_path = dataDir + dataset + '/' + filename
    print("img_path: ", img_path)
    dst_img_dir = os.path.join(img_dir, dataset)
    dst_imgpath = dst_img_dir + '/' + filename
    print("dst_imgpath: ", dst_imgpath)

    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports a missing or undecodable file by returning None
        # instead of raising; skip the image before anything is written.
        print(filename + " could not be read, skipped")
        return

    dst_anno_dir = os.path.join(anno_dir, dataset)
    mkr(dst_anno_dir)
    mkr(dst_img_dir)
    # splitext copes with extensions of any length (".jpeg", ".png", ...).
    anno_path = dst_anno_dir + '/' + os.path.splitext(filename)[0] + '.xml'

    shutil.copy(img_path, dst_imgpath)

    # img.shape is (height, width, channels); the header wants width first.
    head = headstr % (filename, img.shape[1], img.shape[0], img.shape[2])
    write_xml(anno_path, head, objs, tailstr)


def showimg(coco, dataset, img, classes, cls_id, show=True):
    """Collect the bounding boxes of the wanted classes for one image.

    Args:
        coco: Object providing ``getAnnIds`` and ``loadAnns``.
        dataset: Split name; only used to locate the image when ``show`` is true.
        img: Image record; its ``'id'`` and ``'file_name'`` keys are read.
        classes: Mapping from category id to category name.
        cls_id: Category id(s) passed as ``catIds`` to select annotations.
        show: When true, draw the boxes on the image and display it with
            matplotlib.

    Returns:
        A list of ``[class_name, xmin, ymin, xmax, ymax]`` records with integer
        coordinates, one per annotation whose class is in ``classes_names``.
    """
    # Look up the annotations of this image for the requested categories.
    annIds = coco.getAnnIds(imgIds=img['id'], catIds=cls_id, iscrowd=None)
    anns = coco.loadAnns(annIds)

    objs = []
    for ann in anns:
        class_name = classes[ann['category_id']]
        if class_name not in classes_names or 'bbox' not in ann:
            continue
        bbox = ann['bbox']
        # bbox holds [x, y, width, height]; convert it to corner coordinates.
        xmin = int(bbox[0])
        ymin = int(bbox[1])
        xmax = int(bbox[2] + bbox[0])
        ymax = int(bbox[3] + bbox[1])
        objs.append([class_name, xmin, ymin, xmax, ymax])

    if show:
        # The image is only needed for display, so it is opened here rather
        # than on every call, and the context manager releases the file handle.
        with Image.open('%s/%s/%s' % (dataDir, dataset, img['file_name'])) as image:
            draw = ImageDraw.Draw(image)
            for _, xmin, ymin, xmax, ymax in objs:
                draw.rectangle([xmin, ymin, xmax, ymax])
            plt.figure()
            plt.axis('off')
            plt.imshow(image)
            plt.show()

    return objs


for dataset in datasets_list:
    # e.g. <dataDir>/annotations/instances_train2017.json
    annFile = '{}/annotations/instances_{}.json'.format(dataDir, dataset)

    # Load the annotation file through the COCO API.
    coco = COCO(annFile)

    # Category id -> name for every category in the dataset.
    classes = id2name(coco)
    # Ids of the categories that are being extracted.
    classes_ids = coco.getCatIds(catNms=classes_names)

    # Query the images class by class and take the union: passing several
    # category ids to getImgIds at once only returns images that contain all
    # of them. The set also ensures an image holding several wanted classes
    # is copied and annotated once instead of once per class.
    img_ids = set()
    for cls in classes_names:
        cls_id = coco.getCatIds(catNms=[cls])
        img_ids.update(coco.getImgIds(catIds=cls_id))

    for imgId in tqdm(sorted(img_ids)):
        img = coco.loadImgs(imgId)[0]
        filename = img['file_name']
        # Boxes of every wanted class in this image.
        objs = showimg(coco, dataset, img, classes, classes_ids, show=False)
        save_annotations_and_imgs(coco, dataset, filename, objs)

2. Extraiga clases específicas del conjunto de datos en formato VOC

# VOC数据集提取某个类或者某些类
# !/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import xml.etree.ElementTree as ET
import shutil

# Adjust these paths to your own setup.
ann_filepath = r'Annotations/'
img_filepath = r'JPEGImages/'
img_savepath = r'imgs/'
ann_savepath = r'xmls/'
# makedirs(exist_ok=True) also creates missing parent directories and needs no
# separate existence check beforehand.
os.makedirs(img_savepath, exist_ok=True)
os.makedirs(ann_savepath, exist_ok=True)

# All categories of the VOC dataset, for reference:
# classes = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
#             'bus', 'car', 'cat', 'chair', 'cow', 'diningtable',
#              'dog', 'horse', 'motorbike', 'pottedplant',
#           'sheep', 'sofa', 'train', 'person','tvmonitor']

classes = ['car']  # categories to extract

def save_annotation(file):
    """Filter one annotation file down to the wanted classes.

    Parses ``file`` from ``ann_filepath`` and removes every ``object`` element
    whose ``name`` text is not listed in the module-level ``classes``. If at
    least one object is kept, the pruned tree is written to
    ``ann_savepath + file``.

    Returns:
        True when the filtered annotation was written, False when no object
        of a wanted class was found (nothing is written in that case).
    """
    tree = ET.parse(ann_filepath + '/' + file)
    root = tree.getroot()

    kept_any = False
    for node in root.findall("object"):
        if node.find("name").text in classes:
            kept_any = True
        else:
            root.remove(node)

    if not kept_any:
        return False
    tree.write(ann_savepath + file)
    return True

def save_images(file, extensions=(".png", ".jpg", ".jpeg")):
    """Copy the image that belongs to annotation ``file`` and log its name.

    The image is looked up in ``img_filepath`` under the annotation's base
    name, trying each of ``extensions`` in order, and copied to
    ``img_savepath``. The base name is then appended as one line to
    ``train.txt`` (rename that file as needed; it is meant for building the
    train/test list).

    Args:
        file: Annotation file name, e.g. ``'000001.xml'``.
        extensions: Image extensions to try. ``.png`` comes first so datasets
            that worked before resolve to the same file.

    Returns:
        True once the image has been copied and logged.

    Raises:
        FileNotFoundError: if no image with any of the extensions exists.
    """
    stem = os.path.splitext(file)[0]
    for ext in extensions:
        name_img = img_filepath + stem + ext
        if os.path.isfile(name_img):
            break
    else:
        raise FileNotFoundError(
            "no image found for %s in %s (tried %s)"
            % (file, img_filepath, ", ".join(extensions))
        )
    shutil.copy(name_img, img_savepath)
    with open('train.txt', 'a') as file_txt:
        file_txt.write(stem + "\n")
    return True


if __name__ == '__main__':
    # sorted() gives a reproducible processing order and train.txt order.
    for f in sorted(os.listdir(ann_filepath)):
        # Skip anything that is not an annotation (e.g. OS metadata files),
        # which would otherwise make the XML parser fail.
        if not f.lower().endswith('.xml'):
            continue
        print(f)
        # Copy the image only when the annotation kept a wanted object.
        if save_annotation(f):
            save_images(f)

3. Modifique el nombre de una determinada clase en el conjunto de datos en formato VOC.

#!/usr/bin/env python2
# -*- coding: utf-8 -*-

import os
import xml.etree.ElementTree as ET

origin_ann_dir = r'xmls/'  # directory holding the original annotation files
new_ann_dir = r'xmls/'  # output directory; identical here, so files are rewritten in place

# Only files directly inside origin_ann_dir are processed: every path is built
# from origin_ann_dir itself, so a recursive walk would not reach files in
# sub-directories anyway. A missing directory simply yields no work.
filenames = sorted(os.listdir(origin_ann_dir)) if os.path.isdir(origin_ann_dir) else []
for filename in filenames:
    origin_ann_path = r'%s%s' % (origin_ann_dir, filename)
    if not os.path.isfile(origin_ann_path):  # skip sub-directories
        continue
    print("process...")
    new_ann_path = r'%s%s' % (new_ann_dir, filename)
    tree = ET.parse(origin_ann_path)  # parse the XML annotation
    root = tree.getroot()
    for obj in root.findall('object'):  # every <object> under the root
        name_node = obj.find('name')
        name = str(name_node.text)
        # Option 1: delete the objects of given classes.
        # if name in ["car_head"]:
        #     root.remove(obj)

        # Option 2: rename classes - merge car, bus and truck into "car".
        if name in ["car", "bus", "truck"]:
            name_node.text = "car"

    # # Option 3: report labels that are not part of the label map.
    # for obj in root.findall('object'):
    #     name = str(obj.find('name').text)
    #     if not (name in ["chepai", "chedeng", "chebiao", "person"]):
    #         print(filename + "------------->label is error--->" + name)

    # # Option 4: make <filename> inside the XML match the annotation's own
    # # file name, fixing it when they differ.
    # get_name = str(root.find('filename').text)
    # if filename.replace(".xml", ".jpg") != get_name:
    #     print("{}-->name is inconformity!".format(filename))
    #     root.find('filename').text = filename.replace(".xml", ".jpg")
    # else:
    #     continue

    tree.write(new_ann_path)  # save the (possibly modified) annotation


Supongo que te gusta

Origin blog.csdn.net/qq_39056987/article/details/125082396
Recomendado
Clasificación