将 YOLOv5 的检测结果保存为 LabelMe 格式

1. LabelMe标注文件的格式

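一个标注文件的格式如下，与 VOC 的标注格式很像。注意：按 LabelMe 规范，`<nrows>` 是图像高度（行数），`<ncols>` 是图像宽度（列数）；下例中这两个值写反了（框的 x 坐标已超过 1080，可见图像实际宽 1920、高 1080）。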

<?xml version="1.0" ?>
<annotation>
  <filename>0a8f1803ae7a0d65c5dd5561167e6a30</filename>
  <folder></folder>
  <source>
    <sourceImage></sourceImage>
    <sourceAnnotation>Datumaro</sourceAnnotation>
  </source>
  <imagesize>
    <nrows>1920</nrows>
    <ncols>1080</ncols>
  </imagesize>
  <object>
    <name>行人框(属性)</name>
    <deleted>0</deleted>
    <verified>0</verified>
    <occluded>no</occluded>
    <date></date>
    <id>0</id>
    <parts>
      <hasparts></hasparts>
      <ispartof></ispartof>
    </parts>
    <type>bounding_box</type>
    <polygon>
      <pt>
        <x>1104</x>
        <y>243</y>
      </pt>
      <pt>
        <x>1291</x>
        <y>605</y>
      </pt>
      <user_name></user_name>
    </polygon>
    <attributes></attributes>
  </object>
  <object>
    <name>行人框(属性)</name>
    <deleted>0</deleted>
    <verified>0</verified>
    <occluded>no</occluded>
    <date></date>
    <id>1</id>
    <parts>
      <hasparts></hasparts>
      <ispartof></ispartof>
    </parts>
    <type>bounding_box</type>
    <polygon>
      <pt>
        <x>1639</x>
        <y>241</y>
      </pt>
      <pt>
        <x>1709</x>
        <y>334</y>
      </pt>
      <user_name></user_name>
    </polygon>
    <attributes></attributes>
  </object>
</annotation>

2. 转换代码

"""
author:guopei
date:2020.02.23
"""
import os
from PIL import Image,ImageDraw,ImageFont
from tqdm import tqdm
import xml.dom.minidom

import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from models.experimental import attempt_load
from utils.general import letterbox, non_max_suppression, scale_coords


class Yolov5Detect(object):
    """Thin YOLOv5 inference wrapper: load weights once, then detect per image path.

    The model is converted to FP16 and placed on ``cuda:<device>``, so a CUDA
    GPU is required.
    """

    def __init__(self, weights='./weights/yolov5x.pt', device=0, img_size=800, conf=0.65, iou=0.5):
        """Load the model and run one warm-up forward pass.

        Args:
            weights: path to the YOLOv5 ``.pt`` checkpoint.
            device: CUDA device index; the model is placed on ``cuda:<device>``.
            img_size: letterbox target size; should be a multiple of 32.
            conf: confidence threshold passed to NMS.
            iou: IoU threshold passed to NMS.
        """
        with torch.no_grad():
            self.device = "cuda:%s" % device
            self.model = attempt_load(weights, map_location=self.device)  # load FP32 model
            self.model.half()  # to FP16
            self.imgsz = img_size  # should be a multiple of 32
            self.conf = conf
            self.iou = iou
            # Warm-up: run the model once on a dummy input.
            temp_img = torch.zeros((1, 3, self.imgsz, self.imgsz), device=self.device)
            _ = self.model(temp_img.half())

    def pre_process(self, img_path):
        """Read an image and letterbox it into a contiguous CHW, RGB array.

        Returns:
            (img, img0): the letterboxed CHW array fed to the model and the
            original BGR image as read by OpenCV.

        Raises:
            AssertionError: if OpenCV cannot read ``img_path``.
        """
        img0 = cv2.imread(img_path)
        assert img0 is not None, "Image Not Found " + img_path
        img = letterbox(img0, new_shape=self.imgsz)[0]
        # BGR -> RGB and HWC -> CHW (3 x H x W after letterboxing).
        img = img[:, :, ::-1].transpose(2, 0, 1)
        img = np.ascontiguousarray(img)  # the reversed/transposed view is not contiguous
        return img, img0

    def predict(self, img_path):
        """Run the raw forward pass on one image (no NMS).

        Returns:
            (pred, img, img0): raw model output, the normalized FP16 input
            tensor (batch of one), and the original image.
        """
        img, img0 = self.pre_process(img_path)
        img = torch.from_numpy(img).to(self.device)
        img = img.half()  # uint8 to fp16
        img /= 255.0  # 0 - 255 to 0.0 - 1.0
        if img.ndimension() == 3:
            img = img.unsqueeze(0)  # add the batch dimension

        # Inference only: without no_grad() every call records an autograd
        # graph through the model, wasting (GPU) memory.
        with torch.no_grad():
            pred = self.model(img, augment=False)[0]
        return pred, img, img0

    def post_process(self, img_path):
        """Detect objects in one image and map boxes back to original pixels.

        Returns:
            (pred, img0): ``pred`` is ``None`` when NMS yields no result for
            the image, otherwise a (possibly empty) list of
            ``[x1, y1, x2, y2, conf, cls]`` rows in original-image
            coordinates; ``img0`` is the original image.
        """
        pred, img, img0 = self.predict(img_path)

        # Apply NMS; it returns one entry per image and the batch holds one image.
        pred = non_max_suppression(pred, self.conf, self.iou, classes=None, agnostic=False)[0]
        if pred is not None:
            if len(pred):
                # Undo letterbox scaling/padding so boxes refer to img0.
                pred[:, :4] = scale_coords(img.shape[2:], pred[:, :4], img0.shape).round()
            # Always hand back a plain list, even when there are no detections.
            pred = pred.detach().cpu().tolist()
        return pred, img0



def get_image_list(image_dir, suffix=('jpg', 'jpeg', 'JPG', 'JPEG', 'png')):
    """Recursively collect image files under ``image_dir``.

    Args:
        image_dir: root directory to walk.
        suffix: accepted file extensions without the leading dot; matching is
            case-sensitive.

    Returns:
        Sorted list of matching file paths, or ``[]`` (after printing a
        message) when ``image_dir`` does not exist.
    """
    if not os.path.exists(image_dir):
        print("PATH:%s not exists" % image_dir)
        return []
    imglist = []
    for root, _dirs, files in os.walk(image_dir):
        for filename in files:
            # splitext yields '' for dot-less names (e.g. a file called "png")
            # and for dotfiles such as ".jpg", which split('.') wrongly matched.
            ext = os.path.splitext(filename)[1][1:]
            if ext and ext in suffix:
                imglist.append(os.path.join(root, filename))
    # os.walk order is filesystem-dependent; sort for reproducible runs.
    return sorted(imglist)


def _append_text_element(dom, parent, tag, text=''):
    """Append ``<tag>text</tag>`` to ``parent`` and return the new element.

    A text node is always added, even when empty, so that empty fields
    serialize as ``<tag></tag>`` rather than ``<tag/>``, as in the sample
    annotation format.
    """
    elem = dom.createElement(tag)
    elem.appendChild(dom.createTextNode(text))
    parent.appendChild(elem)
    return elem


def _append_point(dom, parent, x, y):
    """Append a ``<pt><x>..</x><y>..</y></pt>`` node with int-truncated coordinates."""
    pt = dom.createElement('pt')
    _append_text_element(dom, pt, 'x', str(int(x)))
    _append_text_element(dom, pt, 'y', str(int(y)))
    parent.appendChild(pt)
    return pt


def CreatXml(imgPath, results, xmlPath, img_shape=None):
    """Write the detections for one image as a LabelMe XML annotation file.

    Args:
        imgPath: path of the source image; its base name without extension
            is written to ``<filename>``.
        results: list of dicts, one per object, with ``'class'`` (label text)
            and ``'bbox'`` (``[x1, y1, x2, y2]``, truncated to int).
        xmlPath: output file path; its parent directory must already exist.
        img_shape: optional ``(height, width, ...)`` of the image. When
            omitted, the image is read with OpenCV only to get its size.

    Raises:
        FileNotFoundError: if ``img_shape`` is omitted and the image cannot be read.
    """
    if img_shape is None:
        img = cv2.imread(imgPath)
        if img is None:
            raise FileNotFoundError("Image Not Found " + imgPath)
        img_shape = img.shape
    height, width = int(img_shape[0]), int(img_shape[1])
    # basename also handles the platform's native path separator (e.g. Windows).
    imgName = os.path.basename(imgPath)

    impl = xml.dom.minidom.getDOMImplementation()
    dom = impl.createDocument(None, 'annotation', None)
    root = dom.documentElement

    _append_text_element(dom, root, 'filename', os.path.splitext(imgName)[0])
    _append_text_element(dom, root, 'folder')

    source = dom.createElement('source')
    root.appendChild(source)
    _append_text_element(dom, source, 'sourceImage')
    _append_text_element(dom, source, 'sourceAnnotation', 'Datumaro')

    # LabelMe convention: nrows = image height (rows), ncols = image width (columns).
    img_size = dom.createElement('imagesize')
    root.appendChild(img_size)
    _append_text_element(dom, img_size, 'nrows', str(height))
    _append_text_element(dom, img_size, 'ncols', str(width))

    for obj_id, result in enumerate(results):
        img_object = dom.createElement('object')
        root.appendChild(img_object)
        _append_text_element(dom, img_object, 'name', result['class'])
        _append_text_element(dom, img_object, 'deleted', '0')
        _append_text_element(dom, img_object, 'verified', '0')
        _append_text_element(dom, img_object, 'occluded', 'no')
        _append_text_element(dom, img_object, 'date')
        _append_text_element(dom, img_object, 'id', str(obj_id))

        parts = dom.createElement('parts')
        img_object.appendChild(parts)
        _append_text_element(dom, parts, 'hasparts')
        _append_text_element(dom, parts, 'ispartof')

        _append_text_element(dom, img_object, 'type', 'bounding_box')

        # The box is stored as a two-point polygon: (x1, y1) then (x2, y2).
        x1, y1, x2, y2 = result['bbox'][:4]
        polygon = dom.createElement('polygon')
        img_object.appendChild(polygon)
        _append_point(dom, polygon, x1, y1)
        _append_point(dom, polygon, x2, y2)
        _append_text_element(dom, polygon, 'user_name')

        _append_text_element(dom, img_object, 'attributes')

    # Explicit UTF-8 for both the file and the XML declaration: labels may be
    # non-ASCII (e.g. Chinese) and the platform default encoding varies.
    with open(xmlPath, 'w', encoding='utf-8') as f:
        dom.writexml(f, addindent='  ', newl='\n', encoding='utf-8')


if __name__ == '__main__':
    detector = Yolov5Detect()
    img_list = get_image_list("imgs")
    # The output directory must exist before any annotation is written.
    os.makedirs("xmls", exist_ok=True)
    for img_path in tqdm(img_list):
        pred, _ = detector.post_process(img_path)
        if pred is None:
            continue
        # Keep only persons: the person class index is 0 (last column of each row).
        persons = [det for det in pred if det[-1] == 0.0]
        if not persons:
            # No person detected: skip instead of writing an empty annotation.
            continue
        objects = [
            {'class': "行人框(属性)", 'bbox': [int(x1), int(y1), int(x2), int(y2)]}
            for x1, y1, x2, y2, _conf, _label in persons
        ]
        # splitext handles every accepted extension (.jpg, .jpeg, .png, ...).
        xml_name = os.path.splitext(os.path.basename(img_path))[0] + ".xml"
        CreatXml(img_path, objects, os.path.join("xmls", xml_name))

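注：代码比较清晰，可以直接使用。运行前需注意：脚本需放在 yolov5 仓库根目录下运行（依赖其 `models` 和 `utils` 模块）；模型运行在 CUDA GPU 上；需准备好 `./weights/yolov5x.pt` 权重文件和存放待检测图片的 `imgs` 目录，生成的 XML 标注写入 `xmls` 目录。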

猜你喜欢

转载自blog.csdn.net/Guo_Python/article/details/114287214
今日推荐