[DETR] Training DETR on the VOC dataset / your own dataset

1. Data preparation

DETR expects the dataset in COCO format.
If you want to train DETR on your own dataset, you can annotate it directly into COCO format with Labelimg.
If you have a VOC dataset, a format conversion is needed first. Much of the conversion code circulating on the Internet is messy, so I wrote my own converter for VOC datasets.


The COCO dataset layout is roughly as follows: the annotations folder holds the JSON files for the corresponding train and val splits, train2017 holds the training images, and the other folders follow the same pattern.
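A minimal sketch of the layout the conversion script below produces (the annotation file names are the ones DETR's COCO loader expects):

VOC2COCO/
├── annotations/
│   ├── instances_train2017.json
│   └── instances_val2017.json
├── train2017/    # training images
└── val2017/      # validation images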
The VOC dataset is stored as shown below. The conversion looks up the image lists used for object detection under the Main folder.
Under the Main folder, train.txt lists the training-set images and val.txt lists the validation-set images.
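For reference, the standard VOC layout the script assumes:

VOC2012/
├── Annotations/         # one XML annotation file per image
├── ImageSets/
│   └── Main/
│       ├── train.txt    # names of training images
│       └── val.txt      # names of validation images
└── JPEGImages/          # all images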
You only need to modify the two paths marked in the comments at the bottom of the script.

import os
import shutil
import sys
import json
import glob
import xml.etree.ElementTree as ET


START_BOUNDING_BOX_ID = 1
# PRE_DEFINE_CATEGORIES = None
# If necessary, pre-define category and its id
PRE_DEFINE_CATEGORIES = {
    "aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
    "bottle": 5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
    "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
    "motorbike": 14, "person": 15, "pottedplant": 16,
    "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20,
}


def get(root, name):
    vars = root.findall(name)
    return vars


def get_and_check(root, name, length):
    vars = root.findall(name)
    if len(vars) == 0:
        raise ValueError("Can not find %s in %s." % (name, root.tag))
    if length > 0 and len(vars) != length:
        raise ValueError(
            "The size of %s is supposed to be %d, but is %d."
            % (name, length, len(vars))
        )
    if length == 1:
        vars = vars[0]
    return vars


def get_filename_as_int(filename):
    try:
        filename = filename.replace("\\", "/")
        filename = os.path.splitext(os.path.basename(filename))[0]
        return int(filename)
    except ValueError:
        raise ValueError(
            "Filename %s is supposed to be an integer." % (filename))


def get_categories(xml_files):
    """Generate category name to id mapping from a list of xml files.

    Arguments:
        xml_files {list} -- A list of xml file paths.

    Returns:
        dict -- category name to id mapping.
    """
    classes_names = []
    for xml_file in xml_files:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        for member in root.findall("object"):
            classes_names.append(member[0].text)
    classes_names = list(set(classes_names))
    classes_names.sort()
    return {name: i for i, name in enumerate(classes_names)}


def convert(xml_files, json_file):
    json_dict = {"images": [], "type": "instances",
                 "annotations": [], "categories": []}
    if PRE_DEFINE_CATEGORIES is not None:
        categories = PRE_DEFINE_CATEGORIES
    else:
        categories = get_categories(xml_files)
    bnd_id = START_BOUNDING_BOX_ID
    for xml_file in xml_files:
        tree = ET.parse(xml_file)
        root = tree.getroot()
        path = get(root, "path")
        if len(path) == 1:
            filename = os.path.basename(path[0].text)
        elif len(path) == 0:
            filename = get_and_check(root, "filename", 1).text
        else:
            raise ValueError("%d paths found in %s" % (len(path), xml_file))
        # The filename must be a number
        image_id = get_filename_as_int(filename)
        size = get_and_check(root, "size", 1)
        width = int(get_and_check(size, "width", 1).text)
        height = int(get_and_check(size, "height", 1).text)
        image = {
            "file_name": filename,
            "height": height,
            "width": width,
            "id": image_id,
        }
        json_dict["images"].append(image)
        # Currently we do not support segmentation.
        #  segmented = get_and_check(root, 'segmented', 1).text
        #  assert segmented == '0'
        for obj in get(root, "object"):
            category = get_and_check(obj, "name", 1).text
            if category not in categories:
                new_id = len(categories)
                categories[category] = new_id
            category_id = categories[category]
            bndbox = get_and_check(obj, "bndbox", 1)
            xmin = int(get_and_check(bndbox, "xmin", 1).text) - 1
            ymin = int(get_and_check(bndbox, "ymin", 1).text) - 1
            xmax = int(get_and_check(bndbox, "xmax", 1).text)
            ymax = int(get_and_check(bndbox, "ymax", 1).text)
            assert xmax > xmin
            assert ymax > ymin
            o_width = abs(xmax - xmin)
            o_height = abs(ymax - ymin)
            ann = {
                "area": o_width * o_height,
                "iscrowd": 0,
                "image_id": image_id,
                "bbox": [xmin, ymin, o_width, o_height],
                "category_id": category_id,
                "id": bnd_id,
                "ignore": 0,
                "segmentation": [],
            }
            json_dict["annotations"].append(ann)
            bnd_id = bnd_id + 1

    for cate, cid in categories.items():
        cat = {"supercategory": "none", "id": cid, "name": cate}
        json_dict["categories"].append(cat)

    os.makedirs(os.path.dirname(json_file), exist_ok=True)
    with open(json_file, "w") as json_fp:
        json_fp.write(json.dumps(json_dict))


if __name__ == "__main__":
    # Only the following two paths need to be modified
    # Root directory of the VOC dataset
    voc_path = "VOC2012"

    # Root directory for the COCO-format output dataset
    save_coco_path = "VOC2COCO"

    # VOC here is split only into a training and a validation set,
    # i.e. train.txt and val.txt
    data_type_list = ["train", "val"]
    for data_type in data_type_list:
        os.makedirs(os.path.join(save_coco_path, data_type + "2017"), exist_ok=True)
        os.makedirs(os.path.join(save_coco_path, data_type + "_xml"), exist_ok=True)
        with open(os.path.join(voc_path, "ImageSets", "Main", data_type + ".txt"), "r") as f:
            txt_ls = f.readlines()
        txt_ls = [i.strip() for i in txt_ls]
        for i in os.listdir(os.path.join(voc_path, "JPEGImages")):
            stem = os.path.splitext(i)[0]
            if stem in txt_ls:
                shutil.copy(os.path.join(voc_path, "JPEGImages", i),
                            os.path.join(save_coco_path, data_type + "2017", i))
                shutil.copy(os.path.join(voc_path, "Annotations", stem + ".xml"),
                            os.path.join(save_coco_path, data_type + "_xml", stem + ".xml"))
        xml_path = os.path.join(save_coco_path, data_type+"_xml")
        xml_files = glob.glob(os.path.join(xml_path, "*.xml"))
        convert(xml_files, os.path.join(save_coco_path,
                "annotations", "instances_"+data_type+"2017.json"))
        shutil.rmtree(xml_path)


After running the script, the VOC2COCO folder contains three items: annotations, train2017, and val2017.

2. Configure DETR

Modify the parameters and hyperparameters in the main.py file.
The dataset_file argument is best left unchanged, set to coco. What you do need to modify is num_classes in the models/detr.py file (around line three hundred). The author explains there that num_classes is not actually the number of categories: COCO has only 80 categories, but its ids are not contiguous and the largest id is 90, so the original code uses 91, i.e. max id + 1. For our converted VOC dataset, whose ids run contiguously from 1 to 20, num_classes is likewise max id + 1 = 21.
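In the official repo this value is set inside build() in models/detr.py; a sketch of the change (the commented-out line is the original; exact line numbers vary by version):

# models/detr.py, inside build(args)
# original: num_classes = 20 if args.dataset_file != 'coco' else 91
num_classes = 21  # VOC: max category id (20) + 1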


Change coco_path to the path of your own COCO-format dataset (VOC2COCO above).
The pre-training weights also need to be modified: the official model has a COCO head (80 categories, 91-slot classifier), so it cannot be loaded directly, while VOC has 20 categories. Set num_classes to 21, run the snippet below, and pass the resulting detr_r50_21.pth file in as the new weights.

import torch

# Load the official COCO-pretrained weights
pretrained_weights = torch.load('detr-r50-e632da11.pth')
num_classes = 21  # VOC: max category id (20) + 1

# Resize the classification head to num_classes + 1
# (the extra slot is the "no object" / background class)
pretrained_weights["model"]["class_embed.weight"].resize_(num_classes + 1, 256)
pretrained_weights["model"]["class_embed.bias"].resize_(num_classes + 1)
torch.save(pretrained_weights, "detr_r50_%d.pth" % num_classes)
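
With the converted dataset and the resized weights in place, training can then be launched along these lines (flag names from the official main.py; paths are the ones assumed above): python main.py --coco_path VOC2COCO --resume detr_r50_21.pth --output_dir outputs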

Training then produces a running log. DETR is especially difficult to train, so expect slow convergence.

3. Drawing

There is a plot_utils.py file under the util folder, which can draw loss and mAP curves.
Add the code to the plot_utils.py file and run it:

if __name__ == "__main__":
    # Replace the path with the eval folder saved under your output directory
    # mAP curve
    files = list(Path("./outputs/eval").glob("*.pth"))
    plot_precision_recall(files)
    plt.show()
    # Replace the path with your output directory
    # Loss curves
    plot_logs(Path("./output"))
    plt.show()
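
For context: the reference implementation's main.py writes a log.txt into --output_dir and, during evaluation, saves per-epoch COCO-eval results as .pth files into an eval subfolder; those are the files the two plotting calls above read. plot_logs and plot_precision_recall are already defined in plot_utils.py, so no extra imports are needed.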

4. Inference

After training completes we get a checkpoint.pth file, and we can use the trained model to run inference on images. The code is as follows:

import numpy as np
from models.detr import build
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms

torch.set_grad_enabled(False)
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
transform_input = transforms.Compose([transforms.Resize(800),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)


def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32, device=out_bbox.device)
    return b


def plot_results(pil_img, prob, boxes, img_save_path):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{CLASSES[cl]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=9,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.savefig(img_save_path)
    plt.show()


def main(checkpoint_path, img_path, img_save_path):
    checkpoint = torch.load(checkpoint_path)
    # Rebuild the model from the args stored in the checkpoint
    model = build(checkpoint['args'])[0]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    # Load the trained weights
    model.load_state_dict(checkpoint['model'])

    model.eval()
    img = Image.open(img_path).convert('RGB')
    size = img.size
    
    inputs = transform_input(img).unsqueeze(0)
    outputs = model(inputs.to(device))
    # The trailing [0, :, :-1] indexing drops the background class
    probs = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    # Adjustable threshold: only output objects with probability above 0.7
    keep = probs.max(-1).values > 0.7
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], size)
    # Save the output result
    ori_img = np.array(img)
    plot_results(ori_img, probs[keep], bboxes_scaled, img_save_path)


if __name__ == "__main__":
    CLASSES = ['N/A', "aeroplane", "bicycle", "bird", "boat",
               "bottle", "bus", "car", "cat", "chair",
               "cow", "diningtable", "dog", "horse",
               "motorbike", "person", "pottedplant",
               "sheep", "sofa", "train", "tvmonitor", "background"]
    main(checkpoint_path="checkpoint.pth", img_path="test.png",
         img_save_path="result2.png")

A few notes:
1. CLASSES is the list of category names for our dataset; make sure your labels are written in the correct order. The first entry, "N/A", is neither background nor foreground: because the indices of our converted dataset start from 1, index 0 has no real category. The background class should be the 21st class, the one with the largest index. Appending "background" at the end, as above, is in my view the most rigorous way to write the list.


2. checkpoint_path: path to the saved weight file
img_path: test image path
img_save_path: save result path

3. The threshold can be modified; as in the paper, only objects with probability greater than 0.7 are output by default.


The inference effect of the model trained with the VOC dataset:
(There is no bicycle in the VOC dataset, so it cannot be recognized)

5. Some minor bugs

1. Rounding problem

UserWarning: __floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').

This warning is caused by a change across torch versions. Change line 44 of the models/position_encoding.py file to the form below.
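A sketch of the fix, following the warning's own suggestion (the commented-out line is the original):

# models/position_encoding.py, in PositionEmbeddingSine.forward
# before: dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.num_pos_feats)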

2. The setting of num_classes

How to set num_classes is discussed in detail on GitHub: How to set num_class.
To quote the author:
num_classes should be set to max_id + 1. For example, in the voc2coco dataset above the indices run from 1 to 20, so num_classes should be set to 20 + 1 = 21, and the class with index 21 is the background class. Because the indices start from 1, the class with index 0 is set to "N/A", neither background nor foreground: a missing class. The author gives an example with 4 categories whose ids are 1, 23, 24, 56: num_classes should then be set to 57, and the class with index 57 is the background class. The missing index values (0, 2-22, 25-55) should all be filled with "N/A"; they are all missing classes.
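
As a quick sanity check, the rule in code (illustrative only, using the author's 4-category example):

# ids actually used in the annotations
category_ids = [1, 23, 24, 56]
num_classes = max(category_ids) + 1  # 57; index 57 becomes the background class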

References

VOC2COCO code reference: GitHub
DETR pre-trained model

Origin: blog.csdn.net/m0_46412065/article/details/128538040