Convert a VOC dataset to a COCO dataset

voc2coco.py
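
The script below chains the usual preprocessing steps: it pairs up jpg and xml files, optionally renames them, splits the pairs into ../train_val/train/ and ../train_val/val/, and finally converts the Pascal VOC xml annotations into COCO-style JSON. With the default paths hard-coded in the script it expects/produces roughly this layout:

../train_val/train/      paired *.jpg + *.xml for the training split
../train_val/val/        paired *.jpg + *.xml for the validation split
../cocodataset/annotations/instances_train.json
../cocodataset/annotations/instances_val.json
../cocodataset/train/, val/, test/    copies of the images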

# coding:utf-8

# pip install numpy opencv-python   (lxml is not actually used; the stdlib xml.etree is)

import os
import glob
import json
import shutil
import numpy as np
import xml.etree.ElementTree as ET
import cv2
import random
import re

cocopath = r"../cocodataset/"

START_BOUNDING_BOX_ID = 1


class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)


def check_jpg_xml_file(path):
    """
    Remove any jpg without a matching xml and any xml without a matching jpg.
    """
    jpgList = os.listdir(path + '/JPEGImages/')
    xmlList = os.listdir(path + '/Annotations/')

    for i in jpgList:
        jpg_name = i.split('.')[0]
        flag = False
        for j in xmlList:
            xml_name = j.split('.')[0]
            if jpg_name == xml_name:
                flag = True
        if not flag:
            print('Orphan file: ' + path + '/JPEGImages/' + jpg_name + '.jpg')
            os.remove(path + '/JPEGImages/' + jpg_name + '.jpg')

    for i1 in xmlList:
        xml_name = i1.split('.')[0]
        flag = False
        for j1 in jpgList:
            jpg_name = j1.split('.')[0]
            if jpg_name == xml_name:
                flag = True
        if not flag:
            print('Orphan file: ' + path + '/Annotations/' + xml_name + '.xml')
            os.remove(path + '/Annotations/' + xml_name + '.xml')


def jianqu_prefix(path, prefix_name):
    """
    Batch-rename the files in the folder by stripping the prefix
    (everything up to and including the first '}', as written by add_prefix).
    Note: prefix_name itself is not used here.
    """
    jpgList = os.listdir(path + '/JPEGImages/')
    xmlList = os.listdir(path + '/Annotations/')

    for name in jpgList:
        # Old name (path + file name); os.sep is the system path separator
        oldname = path + '/JPEGImages' + os.sep + name
        # New name with the prefix stripped
        newname = path + '/JPEGImages' + os.sep + name.split('}')[1]
        os.rename(oldname, newname)

    for name in xmlList:
        # Old name (path + file name); os.sep is the system path separator
        oldname = path + '/Annotations' + os.sep + name
        # New name with the prefix stripped
        newname = path + '/Annotations' + os.sep + name.split('}')[1]
        os.rename(oldname, newname)


def add_prefix(path, prefix_name):
    """
    Batch-rename the files in the folder by putting a prefix of the form
    prefix_name + '{index}' in front of the original name.
    """
    jpgList = os.listdir(path + '/JPEGImages/')
    xmlList = os.listdir(path + '/Annotations/')

    for n, name in enumerate(jpgList):
        # Old name (path + file name); os.sep is the system path separator
        oldname = path + '/JPEGImages' + os.sep + name
        # New name with the prefix added
        newname = path + '/JPEGImages' + os.sep + prefix_name + '{' + str(n) + '}' + name
        os.rename(oldname, newname)

    for m, name in enumerate(xmlList):
        # Old name (path + file name); os.sep is the system path separator
        oldname = path + '/Annotations' + os.sep + name
        # New name with the prefix added
        newname = path + '/Annotations' + os.sep + prefix_name + '{' + str(m) + '}' + name
        os.rename(oldname, newname)


def get(root, name):
    return root.findall(name)


def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise NotImplementedError('Can not find %s in %s.' % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise NotImplementedError('The size of %s is supposed to be %d, but is %d.' % (name, length, len(vars)))
    if length == 1:
        vars = vars[0]
    return vars


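# convert() builds a COCO "instances"-style dictionary with three lists:
#   images:      {file_name, height, width, id}
#   annotations: {id, image_id, category_id, bbox, area, iscrowd, ignore, segmentation}
#                where bbox is in COCO format [xmin, ymin, width, height]
#   categories:  {id, name, supercategory}
# An image whose xml contains a degenerate box (xmax <= xmin or ymax <= ymin)
# is not added to the images list.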
def convert(pre_define_categories, pre_define_categories_numbers, xml_list, json_file):
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    categories = pre_define_categories.copy()
    categories_num = pre_define_categories_numbers.copy()
    bnd_id = START_BOUNDING_BOX_ID
    all_categories = {}
    for index, line in enumerate(xml_list):
        # print("Processing %s"%(line))
        xml_f = line
        tree = ET.parse(xml_f)
        root = tree.getroot()

        filename = os.path.basename(xml_f)[:-4] + ".jpg"
        image_id = 20200904001 + index

        xmlname = xml_f.split(".xml")[0]
        jpgname = xmlname + ".jpg"
        img = cv2.imread(jpgname)
        if img is None:
            # Image is missing or unreadable: remove both the jpg and its xml and skip it
            print(os.path.abspath('..') + jpgname.split('..')[1])
            print('Error: ' + jpgname)
            try:
                os.remove(os.path.abspath('..') + jpgname.split('..')[1])
            except Exception as e:
                print(e)
            try:
                os.remove(os.path.abspath('..') + (xmlname + '.xml').split('..')[1])
            except Exception as e:
                print(e)

            continue

        # print(xmlname)
        # print(jpgname)
        width = img.shape[1]
        height = img.shape[0]

        image = {'file_name': filename, 'height': height, 'width': width, 'id': image_id}

        ## Currently we do not support segmentation
        #  segmented = get_and_check(root, 'segmented', 1).text
        #  assert segmented == '0'
        valid = True
        for obj in get(root, 'object'):
            category = get_and_check(obj, 'name', 1).text
            if category in all_categories:
                all_categories[category] += 1
            else:
                all_categories[category] = 1
            if category not in categories:
                if only_care_pre_define_categories:
                    continue
                new_id = len(categories) + 1
                print(
                    "[warning] category '{}' not in 'pre_define_categories'({}), create new id: {} automatically".format(
                        category, pre_define_categories, new_id))
                categories[category] = new_id
            categories_num[category] += 1
            category_id = categories[category]
            bndbox = get_and_check(obj, 'bndbox', 1)
            xmin = int(float(get_and_check(bndbox, 'xmin', 1).text))
            ymin = int(float(get_and_check(bndbox, 'ymin', 1).text))
            xmax = int(float(get_and_check(bndbox, 'xmax', 1).text))
            ymax = int(float(get_and_check(bndbox, 'ymax', 1).text))
            if xmax <= xmin:
                valid = False
                continue
            if ymax <= ymin:
                valid = False
                continue
            # assert(xmax > xmin), "xmax <= xmin, {}".format(line)
            # assert(ymax > ymin), "ymax <= ymin, {}".format(line)
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            ann = {'area': o_width * o_height, 'iscrowd': 0, 'image_id':
                image_id, 'bbox': [xmin, ymin, o_width, o_height],
                   'category_id': category_id, 'id': bnd_id, 'ignore': 0,
                   'segmentation': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]}
            json_dict['annotations'].append(ann)
            bnd_id = bnd_id + 1
        if valid:
            json_dict['images'].append(image)

    for cate, cid in categories.items():
        cat = {'supercategory': 'none', 'id': cid, 'name': cate}
        json_dict['categories'].append(cat)
    with open(json_file, 'w') as json_fp:
        json_fp.write(json.dumps(json_dict, indent=4, cls=MyEncoder))
    print("------------create {} done--------------".format(json_file))
    print("find {} categories: {} -->>> your pre_define_categories {}: {}".format(len(all_categories),
                                                                                  all_categories.keys(),
                                                                                  len(pre_define_categories),
                                                                                  pre_define_categories.keys()))
    print("category: id --> {}".format(categories))
    print(categories.keys())
    print(categories.values())
    print("all categories : {}".format(all_categories))


def date_shape(source_path, dest_path, train_percent, trainval_percent):
    """
    source_path: root of the VOC-style dataset (Annotations/ and JPEGImages/)
    dest_path: output root; pairs are copied into dest_path/train/ and dest_path/val/
    trainval_percent: fraction of samples copied to dest_path/val/ (the rest go to dest_path/train/)
    train_percent: within that fraction, the share written to test.txt (the rest to val.txt)
    """
    total_xml = os.listdir(source_path + '/Annotations')
    num = len(total_xml)
    indices = list(range(num))
    tv = int(num * trainval_percent)
    tr = int(tv * train_percent)
    trainval = random.sample(indices, tv)
    train = random.sample(trainval, tr)
    ftrainval = open(source_path + '/ImageSets/trainval.txt', 'w')
    ftest = open(source_path + '/ImageSets/test.txt', 'w')
    ftrain = open(source_path + '/ImageSets/train.txt', 'w')
    fval = open(source_path + '/ImageSets/val.txt', 'w')

    ftrainval_1 = open(source_path + '/ImageSets/trainval_1.txt', 'w')
    ftest_1 = open(source_path + '/ImageSets/test_1.txt', 'w')
    ftrain_1 = open(source_path + '/ImageSets/train_1.txt', 'w')
    fval_1 = open(source_path + '/ImageSets/val_1.txt', 'w')

    try:
        for i in indices:
            name = total_xml[i][:-4] + '\n'
            name_xml = source_path + '/Annotations/' + total_xml[i][:-4] + '.xml' + '\n'
            xml_file = source_path + '/Annotations/' + total_xml[i][:-4] + '.xml'
            name_jpg = source_path + '/JPEGImages/' + total_xml[i][:-4] + '.jpg' + '\n'
            jpg_file = source_path + '/JPEGImages/' + total_xml[i][:-4] + '.jpg'
            if i in trainval:
                ftrainval.write(name)
                ftrainval_1.write(name_xml)
                ftrainval_1.write(name_jpg)
                shutil.copy(xml_file, dest_path + '/val/')
                shutil.copy(jpg_file, dest_path + '/val/')
                if i in train:
                    ftest.write(name)
                    ftest_1.write(name_xml)
                    ftest_1.write(name_jpg)
                else:
                    fval.write(name)
                    fval_1.write(name_xml)
                    fval_1.write(name_jpg)
            else:
                ftrain.write(name)
                ftrain_1.write(name_xml)
                ftrain_1.write(name_jpg)
                shutil.copy(xml_file, dest_path + '/train/')
                shutil.copy(jpg_file, dest_path + '/train/')

    except Exception as e:
        print('Missing file: ' + name)
        print(e)

    ftrainval.close()
    ftrain.close()
    fval.close()
    ftest.close()
    ftrainval_1.close()
    ftrain_1.close()
    fval_1.close()
    ftest_1.close()


def get_train_val_data(source_path, dest_path):
    # source_path = 'D:/ai/data/mask/新的数据集'
    # source_path = 'D:/ai/data/mask/data2000/data2000'
    # dest_path = 'D:/ai/data/result/train_val'
    trainval_percent = 0.2
    train_percent = 0.8
    date_shape(source_path, dest_path, train_percent, trainval_percent)


def get_coco_dataset():
    classes = ['mask', 'nomask']
    # classes = [
    #     'face',
    #     'face_mask']
    pre_define_categories = {}
    pre_define_categories_numbers = {}
    for i, cls in enumerate(classes):
        pre_define_categories[cls] = i + 1
        pre_define_categories_numbers[cls] = 0
    # pre_define_categories = {'a1': 1, 'a3': 2, 'a6': 3, 'a9': 4, "a10": 5}
    # convert() reads this flag at module scope, so declare it global here
    global only_care_pre_define_categories
    only_care_pre_define_categories = True
    # only_care_pre_define_categories = False

    # train_ratio = 0.9
    save_json_train = cocopath + 'annotations/instances_train.json'
    save_json_val = cocopath + 'annotations/instances_val.json'
    # save_json_test = cocopath+'annotations/test.json'
    # xml_dir = "./test"

    # xml_list = glob.glob(r"resize/*.xml")
    # xml_list = np.sort(xml_list)
    # np.random.seed(100)
    # np.random.shuffle(xml_list)

    # train_num = int(len(xml_list)*train_ratio)
    # xml_list_train = xml_list[:train_num]
    xml_list_train = glob.glob(r"../train_val/train/*.xml")
    # xml_list_val = xml_list[train_num:]
    xml_list_val = glob.glob(r"../train_val/val/*.xml")
    # xml_list_test = glob.glob(r"dataset/test/*.xml")
    if os.path.exists(cocopath + "annotations"):
        shutil.rmtree(cocopath + "annotations")
    os.makedirs(cocopath + "annotations")
    convert(pre_define_categories, pre_define_categories_numbers, xml_list_train, save_json_train)
    convert(pre_define_categories, pre_define_categories_numbers, xml_list_val, save_json_val)
    # convert(xml_list_test, save_json_test)

    if os.path.exists(cocopath + "train"):
        shutil.rmtree(cocopath + "train")
    os.makedirs(cocopath + "train")
    if os.path.exists(cocopath + "val"):
        shutil.rmtree(cocopath + "val")
    os.makedirs(cocopath + "val")

    if os.path.exists(cocopath + "test"):
        shutil.rmtree(cocopath + "test")
    os.makedirs(cocopath + "test")

    f1 = open(cocopath + "train.txt", "w")
    for xml in xml_list_train:
        img = xml[:-4] + ".jpg"
        f1.write(os.path.basename(xml)[:-4] + "\n")
        shutil.copyfile(img, cocopath + "train/" + os.path.basename(img))

    f2 = open(cocopath + "val.txt", "w")
    for xml in xml_list_val:
        img = xml[:-4] + ".jpg"
        f2.write(os.path.basename(xml)[:-4] + "\n")
        shutil.copyfile(img, cocopath + "val/" + os.path.basename(img))

    #    f3 = open(cocopath+"test.txt", "w")
    #    for xml in xml_list_test:
    #        img = xml[:-4] + ".jpg"
    #        f3.write(os.path.basename(xml)[:-4] + "\n")
    #        shutil.copyfile(img, cocopath + "test/" + os.path.basename(img))
    f1.close()
    f2.close()
    # f3.close()
    print("-------------------------------")
    print("train number:", len(xml_list_train))
    print("val number:", len(xml_list_val))


# print("test number:", len(xml_list_test))

def change_str(path):
    """
    替换文件夹下所有xml文件中包含 no mask 的字符串为 nomask
    """
    str_pattern = r"no mask"
    str_new = r"nomask"
    path_list = os.listdir(path)
    for file in path_list:
        abs_path = os.path.join(path, file)
        if os.path.isfile(abs_path):
            if file.endswith('.xml'):
                print(abs_path)
                with open(abs_path, 'r', encoding="utf-8") as f:
                    str_all = f.read()
                with open(abs_path, 'w', encoding="utf-8") as f:
                    f.write(re.sub(str_pattern, str_new, str_all))
                # rename the file here if needed

        else:
            change_str(abs_path)


if __name__ == '__main__':
    # Check that every jpg has a matching xml (and vice versa)
    # check_jpg_xml_file('D:/ai/data/mask/新的数据集')
    # Replace the string "no mask" with "nomask" in every xml file under the folder
    # change_str('D:/ai/data/mask/data2000/data2000/Annotations')
    # Strip the prefix from file names
    # jianqu_prefix('D:/ai/data/mask/新的数据集', '20210118_newdata_')
    # Add a prefix to file names
    # add_prefix('D:/ai/data/mask/新的数据集', '20210118_newdata_')
    # add_prefix('D:/ai/data/mask/data2000/data2000', '20210118_data2000_')
    # Split train/val by the given ratio and copy the pairs into the train_val folder
    # get_train_val_data('D:/ai/data/mask/data2000/data2000', 'D:/ai/data/result/train_val')
    # get_train_val_data('D:/ai/data/mask/新的数据集', 'D:/ai/data/result/train_val')
    # Generate the COCO dataset
    get_coco_dataset()
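
After the JSON files are written it is worth sanity-checking them. A minimal sketch, assuming pycocotools is installed (pip install pycocotools) and run from the same directory as voc2coco.py so the relative cocopath matches:

from pycocotools.coco import COCO

# Load the generated training annotations and print a few quick statistics
coco = COCO('../cocodataset/annotations/instances_train.json')
print('images:', len(coco.getImgIds()))
print('annotations:', len(coco.getAnnIds()))
print('categories:', [c['name'] for c in coco.loadCats(coco.getCatIds())])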

Origin blog.csdn.net/qq122716072/article/details/112786958