Python batch data-processing scripts: converting a COCO-format dataset to YOLO's darknet format (labelme)

coco—>voc,voc—>darknet,coco—>darknet

I mostly work with the json and darknet formats. I rarely download or use COCO-format datasets, and I have never used the VOC format, so this article keeps the explanation brief and goes straight to the code, presenting the COCO-to-VOC and VOC-to-darknet conversions together. If you need any of these three conversions, this article should be useful to you.

coco—>voc

import os
import time
import json

import numpy as np
import pandas as pd
from tqdm import tqdm
from pycocotools.coco import COCO


def trans_id(category_id, categories=None):
    """Return the zero-based position of a category id in a category list.

    Args:
        category_id: The ``'id'`` value of the category to look up.
        categories: Optional list of category dicts, each holding an ``'id'``
            key. When omitted, the module-level ``cats`` list is used.

    Returns:
        The index of the first entry whose ``'id'`` equals ``category_id``.

    Raises:
        ValueError: If no entry has the requested id.
    """
    if categories is None:
        categories = cats
    # Scan directly instead of rebuilding helper lists on every call.
    for position, category in enumerate(categories):
        if category['id'] == category_id:
            return position
    raise ValueError('{!r} is not a known category id'.format(category_id))


# Input COCO annotation file and output directory for the generated XML files.
anno = '/home/alpha/桌面/22222/Safety/_annotations.coco.json'
xml_dir = '/home/alpha/桌面/22222/Safety/json_xml/'

coco = COCO(anno)  # load the annotation file through the pycocotools API
cats = coco.loadCats(coco.getCatIds())  # category records; read by trans_id()

# Timestamp string; the conversion below does not use it.
dttm = time.strftime("%Y%m%d%H%M%S", time.localtime())

# Create the output directory up front so that writing the first XML file
# does not fail when the directory is missing.
os.makedirs(xml_dir, exist_ok=True)

# Read the same annotation file as plain JSON. The encoding is explicit so
# the result does not depend on the platform's default locale.
with open(anno, 'r', encoding='utf-8') as load_f:
    f = json.load(load_f)

imgs = f['images']  # one record per image

cat = f['categories']
df_cate = pd.DataFrame(f['categories'])  # categories as a table
df_cate_sort = df_cate.sort_values(["id"], ascending=True)  # ordered by category id
categories = list(df_cate_sort['name'])  # category names in id order
print('categories = ', categories)
df_anno = pd.DataFrame(f['annotations'])  # one row per annotation

from xml.sax.saxutils import escape

# Map every category id to its name. Indexing the category list by id is only
# correct when the ids happen to equal the list positions.
category_names = {entry['id']: entry['name'] for entry in cat}

# Group the annotation rows by image once instead of scanning the whole table
# for every image. An empty annotation table has no 'image_id' column.
if 'image_id' in df_anno.columns:
    annos_by_image = {
        image_id: rows for image_id, rows in df_anno.groupby('image_id')
    }
else:
    annos_by_image = {}

for img in tqdm(imgs):  # one XML file per image record
    file_name = img['file_name']
    height = img['height']
    img_id = img['id']
    width = img['width']

    # Header of the annotation document.
    xml_content = [
        "<annotation>",
        "\t<folder>VOC2007</folder>",
        # Escape &, < and > so the document stays well-formed.
        "\t<filename>" + escape(file_name) + "</filename>",
        "\t<size>",
        "\t\t<width>" + str(width) + "</width>",
        "\t\t<height>" + str(height) + "</height>",
        "\t</size>",
        "\t<segmented>0</segmented>",
    ]

    # All annotations that belong to this image (None when there are none).
    annos = annos_by_image.get(img_id)
    if annos is not None:
        for _, row in annos.iterrows():
            bbox = row["bbox"]  # read as [x, y, width, height]
            # Raises KeyError when the annotation refers to an unknown category.
            cate_name = category_names[int(row["category_id"])]

            # One <object> element per annotation; the box is converted from
            # x/y/width/height to integer xmin/ymin/xmax/ymax.
            xml_content.append("<object>")
            xml_content.append("<name>" + escape(cate_name) + "</name>")
            xml_content.append("<pose>Unspecified</pose>")
            xml_content.append("<truncated>0</truncated>")
            xml_content.append("<difficult>0</difficult>")
            xml_content.append("<bndbox>")
            xml_content.append("<xmin>" + str(int(bbox[0])) + "</xmin>")
            xml_content.append("<ymin>" + str(int(bbox[1])) + "</ymin>")
            xml_content.append("<xmax>" + str(int(bbox[0] + bbox[2])) + "</xmax>")
            xml_content.append("<ymax>" + str(int(bbox[1] + bbox[3])) + "</ymax>")
            xml_content.append("</bndbox>")
            xml_content.append("</object>")
    xml_content.append("</annotation>")

    # Swap whatever extension the image has for '.xml'; a plain
    # replace('.jpg', ...) would leave other extensions unchanged.
    xml_name = os.path.splitext(file_name)[0] + '.xml'
    xml_path = os.path.join(xml_dir, xml_name)
    with open(xml_path, 'w', encoding="utf8") as xml_file:
        xml_file.write('\n'.join(xml_content))

voc—>darknet

import argparse
import glob
import os
import xml.etree.ElementTree as ET
import json
from tqdm import tqdm

def parse_args():
    """Parse the command-line options of the conversion script.

    Returns:
        The parsed namespace with ``raw_label_dir`` (VOC XML directory),
        ``pic_dir`` (image directory) and ``save_dir`` (output directory).
    """
    # (flag, help text, default value) for every supported option.
    option_table = (
        ('--raw_label_dir', 'the path of raw label',
         '/home/alpha/桌面/22222/Safety/json_xml'),
        ('--pic_dir', 'the path of picture',
         '/home/alpha/桌面/22222/Safety/valid'),
        ('--save_dir', 'the path of new label',
         '/home/alpha/桌面/22222/Safety/json'),
    )
    parser = argparse.ArgumentParser(description='xml2json')
    for flag, help_text, default_value in option_table:
        parser.add_argument(flag, help=help_text, default=default_value)
    return parser.parse_args()

def read_xml_gtbox_and_label(xml_path):
    """Read the labelled boxes and the image size from one VOC XML file.

    Args:
        xml_path: Path of the XML annotation file.

    Returns:
        A tuple ``(points, width, height)``. ``points`` is a list with one
        ``[name, [xmin, ymin, xmax, ymax]]`` entry per ``<object>`` element,
        with the coordinates as floats. ``width`` and ``height`` are the
        integers found under ``<size>``.
    """
    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    points = []
    for obj in root.iter('object'):
        cls = obj.find('name').text
        # <pose> is deliberately not read: its value was never used, and
        # looking it up crashed on files that do not contain the element.
        xmlbox = obj.find('bndbox')
        box = [
            float(xmlbox.find(tag).text)
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ]
        points.append([cls, box])
    return points, width, height

def _labelme_shape(cls, box):
    """Build one labelme shape dict from a class name and its box.

    ``box`` is ``[xmin, ymin, xmax, ymax]``. The classes "left head" and
    "right head" become a single key point labelled 'head', placed at
    (xmin, ymax) and (xmax, ymax) respectively; every other class becomes a
    rectangle spanning the two box corners.
    """
    xmin, ymin, xmax, ymax = box
    if cls == "left head":
        label, shape_type = 'head', 'point'
        shape_points = [[xmin, ymax]]
    elif cls == "right head":
        label, shape_type = 'head', 'point'
        shape_points = [[xmax, ymax]]
    else:
        label, shape_type = cls, 'rectangle'
        shape_points = [[xmin, ymin], [xmax, ymax]]
    return {
        'label': label,
        'points': shape_points,
        'group_id': None,
        'shape_type': shape_type,
        'flags': {},
    }


def main():
    """Convert every VOC XML file in the label directory to a labelme JSON file.

    The directories come from ``parse_args()``. For each ``<name>.xml`` found
    in ``raw_label_dir`` a ``<name>.json`` file is written to ``save_dir``,
    referring to the image as ``<name>.jpg``.
    """
    args = parse_args()
    # Create the output directory so the first write cannot fail on it.
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)

    labels = glob.glob(os.path.join(args.raw_label_dir, '*.xml'))
    for label_abs in tqdm(labels):
        # splitext removes exactly the extension; rstrip('.xml') strips any
        # trailing '.', 'x', 'm' or 'l' characters and so mangles names
        # such as 'ball.xml'.
        label_name = os.path.splitext(os.path.basename(label_abs))[0]
        img_path = label_name + '.jpg'
        points, width, height = read_xml_gtbox_and_label(label_abs)

        json_str = {
            'version': '4.5.6',
            'flags': {},
            'shapes': [_labelme_shape(cls, box) for cls, box in points],
            'imagePath': img_path,
            'imageData': None,
            'imageHeight': height,
            'imageWidth': width,
        }
        with open(os.path.join(args.save_dir, label_name + '.json'), 'w') as f:
            json.dump(json_str, f, indent=2)

if __name__ == '__main__':
    main()

Guess you like

Origin blog.csdn.net/weixin_45354497/article/details/130654044