PyTorch-YOLOv3 dataset preparation workflow

GitHub link: https://github.com/eriklindernoren/PyTorch-YOLOv3

Training PyTorch-YOLOv3 on your own dataset

1. Generate the network definition file yolov3-custom.cfg

cd config/
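bash create_custom_model.sh <num-classes>  # <num-classes> is the number of classes in your dataset
PS: run k-means clustering on the bounding-box sizes in your dataset to generate anchors, then replace the anchors setting in yolov3-custom.cfg with the result.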
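
A minimal sketch of that clustering step, assuming the box sizes have already been collected into an (N, 2) NumPy array of widths and heights in pixels (for example, the width and height columns of the label files multiplied by the 416-pixel input size; this ignores the padding to a square, so treat it as an approximation). It uses the 1 - IoU distance popularized by YOLOv2 and moves each center to the median of its members. The function names are illustrative, not part of the repository.

```python
import numpy as np

def iou_wh(boxes, anchors):
    """IoU between (N, 2) box sizes and (K, 2) anchor sizes, computed as if
    every box and anchor shared the same center; returns an (N, K) array."""
    inter = (np.minimum(boxes[:, None, 0], anchors[None, :, 0]) *
             np.minimum(boxes[:, None, 1], anchors[None, :, 1]))
    area_boxes = boxes[:, 0] * boxes[:, 1]
    area_anchors = anchors[:, 0] * anchors[:, 1]
    return inter / (area_boxes[:, None] + area_anchors[None, :] - inter)

def kmeans_anchors(boxes, k=9, iters=300, seed=0):
    """k-means on (w, h) pairs with distance 1 - IoU; each center is updated
    to the median of its members. Returns k anchors sorted by area."""
    boxes = np.asarray(boxes, dtype=float)
    rng = np.random.default_rng(seed)
    anchors = boxes[rng.choice(len(boxes), k, replace=False)].copy()
    assign = None
    for _ in range(iters):
        # nearest anchor = highest IoU (i.e. lowest 1 - IoU distance)
        new_assign = np.argmax(iou_wh(boxes, anchors), axis=1)
        if assign is not None and np.array_equal(new_assign, assign):
            break  # assignments stopped changing: converged
        assign = new_assign
        for j in range(k):
            members = boxes[assign == j]
            if len(members):  # empty clusters keep their previous center
                anchors[j] = np.median(members, axis=0)
    return anchors[np.argsort(anchors[:, 0] * anchors[:, 1])]

def format_anchors(anchors):
    """Format anchors as 'w,h,  w,h, ...' for the anchors line of a .cfg file."""
    return ",  ".join(f"{int(round(float(w)))},{int(round(float(h)))}" for w, h in anchors)

# Usage, with boxes being your (N, 2) array of box sizes:
# anchors = kmeans_anchors(boxes, k=9)
# print(format_anchors(anchors))
```

In the standard YOLOv3 cfg the anchors line is repeated in each [yolo] section, and mask selects which anchors that scale uses, so the same nine values go into every anchors line.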

2. Create the classes.names file

Add class names to data/custom/classes.names. This file should have one row per class name.
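
For example, a sketch that writes the file and returns the resulting name-to-index mapping (the helper name and the class names are hypothetical). The zero-based row number of each name is the label_idx used in the annotation files:

```python
import os

def write_class_names(names, path):
    """Write one class name per line; the 0-based row number becomes label_idx."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as f:
        for name in names:
            f.write(name + "\n")
    return {name: idx for idx, name in enumerate(names)}

# write_class_names(["person", "car", "dog"], "data/custom/classes.names")
# returns {'person': 0, 'car': 1, 'dog': 2}
```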

3. Add the images

Move the images of your dataset to data/custom/images/
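
If the images currently live elsewhere (for example in a Pascal VOC-style JPEGImages folder, which is an assumed layout), a small helper like this hypothetical copy_images can collect them. It copies rather than moves, so the original dataset stays intact, and only picks up .jpg/.jpeg/.png files:

```python
import os
import shutil

IMG_EXTS = (".jpg", ".jpeg", ".png")

def copy_images(src_dir, dst_dir="data/custom/images"):
    """Copy image files from src_dir into dst_dir; return how many were copied."""
    os.makedirs(dst_dir, exist_ok=True)
    count = 0
    for name in sorted(os.listdir(src_dir)):
        if name.lower().endswith(IMG_EXTS):
            # copy2 also preserves file timestamps
            shutil.copy2(os.path.join(src_dir, name), os.path.join(dst_dir, name))
            count += 1
    return count

# copy_images("VOCdevkit/VOC2007/JPEGImages")  # hypothetical source path
```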

4. Create the annotation files

Move your annotations to data/custom/labels/

Each image needs exactly one .txt annotation file with the same base name (e.g. 0001.jpg and 0001.txt).
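
Mismatched names are easy to miss, so a quick check like the following sketch (the function name is hypothetical) lists images without a label file and label files without an image:

```python
import os

def find_unpaired(img_dir="data/custom/images", label_dir="data/custom/labels"):
    """Return (images with no .txt label, .txt labels with no image), by base name."""
    imgs = {os.path.splitext(f)[0] for f in os.listdir(img_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))}
    labels = {os.path.splitext(f)[0] for f in os.listdir(label_dir)
              if f.endswith(".txt")}
    return sorted(imgs - labels), sorted(labels - imgs)

# missing_labels, orphan_labels = find_unpaired()
```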

Each row in the annotation file defines one bounding box, using the syntax label_idx x_center y_center width height. The coordinates must be normalized to [0, 1] (divide x_center and width by the image width, and y_center and height by the image height), and label_idx is zero-indexed and corresponds to the row number of the class name in data/custom/classes.names. If your annotations are in Pascal VOC XML format, the following script converts them to this format:
import os
try:
    import xml.etree.cElementTree as ET  # removed in Python 3.9
except ImportError:
    import xml.etree.ElementTree as ET

def GetLabel_id(ClassNameFile):
    """
    Map each class name to its zero-based row index in ClassNameFile
    :param ClassNameFile: path to classes.names (one class name per line)
    :return: dict {class_name: label_id}
    """
    # strip() rather than line[:-1], so the last name is not truncated
    # when the file has no trailing newline (it also removes '\r')
    with open(ClassNameFile, 'r') as fp:
        return {line.strip(): Label_id for Label_id, line in enumerate(fp)}

def xml2txt(XmlPath, TxtPath, ClassNameFile):
    """
    Convert dataset annotations from Pascal VOC xml format to YOLO txt format
    :param XmlPath: directory containing the .xml annotation files
    :param TxtPath: output directory for the .txt files (e.g. data/custom/labels)
    :param ClassNameFile: path to classes.names
    """
    Label_dic = GetLabel_id(ClassNameFile)
    os.makedirs(TxtPath, exist_ok=True)

    for xmlfile in os.listdir(XmlPath):
        if not xmlfile.endswith('.xml'):
            continue  # skip anything that is not an annotation file
        root = ET.parse(os.path.join(XmlPath, xmlfile)).getroot()
        width = float(root.find('size').find('width').text)
        height = float(root.find('size').find('height').text)

        # one txt per image: 0001.xml -> 0001.txt
        txtfile = os.path.splitext(xmlfile)[0] + '.txt'
        with open(os.path.join(TxtPath, txtfile), 'w') as fp:
            for Object in root.findall('object'):
                ObjName = Object.find('name').text.strip()
                BndBox = Object.find('bndbox')
                x1 = float(BndBox.find('xmin').text)
                y1 = float(BndBox.find('ymin').text)
                x2 = float(BndBox.find('xmax').text)
                y2 = float(BndBox.find('ymax').text)

                # box center and size, normalized by the image size;
                # raises KeyError if ObjName is missing from classes.names
                label_id = Label_dic[ObjName]
                xcenter = (x1 + x2) / 2.0 / width
                ycenter = (y1 + y2) / 2.0 / height
                w = (x2 - x1) / width
                h = (y2 - y1) / height
                fp.write('%d %.6f %.6f %.6f %.6f\n' % (label_id, xcenter, ycenter, w, h))
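
To make the normalization concrete, here is the same arithmetic for a single box, plus the inverse conversion, which can be handy for drawing converted labels back onto an image to confirm they line up. Both helpers are illustrative and not part of the repository:

```python
def voc_box_to_yolo(x1, y1, x2, y2, img_w, img_h):
    """Pixel box (xmin, ymin, xmax, ymax) -> normalized (x_center, y_center, w, h)."""
    return ((x1 + x2) / 2.0 / img_w,
            (y1 + y2) / 2.0 / img_h,
            (x2 - x1) / img_w,
            (y2 - y1) / img_h)

def yolo_to_voc_box(xc, yc, w, h, img_w, img_h):
    """Inverse conversion: normalized box -> pixel (xmin, ymin, xmax, ymax)."""
    return ((xc - w / 2) * img_w, (yc - h / 2) * img_h,
            (xc + w / 2) * img_w, (yc + h / 2) * img_h)

# A 640x480 image with a class-0 box from (100, 120) to (300, 360):
box = voc_box_to_yolo(100, 120, 300, 360, 640, 480)
print("0 " + " ".join(f"{v:.6f}" for v in box))
# prints: 0 0.312500 0.500000 0.312500 0.500000
```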

5. Split the data into training and validation sets

In data/custom/train.txt and data/custom/valid.txt, list the paths of the images used for training and validation respectively, one path per line. The following script shuffles the images and writes both files:
import os
import numpy as np

def split_train_val(ImgPath, TrainRatio, TxtPath):
    """
    Randomly split the images in ImgPath and write train.txt and valid.txt
    :param ImgPath: image directory, e.g. data/custom/images
    :param TrainRatio: fraction of images used for training, e.g. 0.9
    :param TxtPath: output directory for train.txt and valid.txt, e.g. data/custom
    """
    # list image files only, so stray files are not written into the lists
    imglist = sorted(f for f in os.listdir(ImgPath)
                     if f.lower().endswith(('.jpg', '.jpeg', '.png')))
    rand = np.random.permutation(len(imglist))  # shuffled indices
    trainnum = int(len(imglist) * TrainRatio)

    with open(os.path.join(TxtPath, 'train.txt'), 'w') as fp_train:
        for i in rand[:trainnum]:
            fp_train.write(os.path.join(ImgPath, imglist[i]) + '\n')
    with open(os.path.join(TxtPath, 'valid.txt'), 'w') as fp_val:
        for i in rand[trainnum:]:
            fp_val.write(os.path.join(ImgPath, imglist[i]) + '\n')
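
The split above draws from NumPy's global random generator without a seed, so each run gives a different split; calling np.random.seed(0) (any fixed value) before split_train_val makes it reproducible. After writing the files, a sketch like the one below (the helper name is hypothetical) reports the split sizes, any path listed in both files, and any listed path that does not exist. Run it from the directory you will launch train.py from, since relative paths such as data/custom/images/0001.jpg are resolved against it:

```python
import os

def check_split(train_txt, valid_txt):
    """Summarize two image-list files: sizes, paths present in both,
    and listed paths that do not exist relative to the current directory."""
    def read(path):
        with open(path) as f:
            return [line.strip() for line in f if line.strip()]
    train, valid = read(train_txt), read(valid_txt)
    return {
        "n_train": len(train),
        "n_valid": len(valid),
        "overlap": sorted(set(train) & set(valid)),
        "missing": [p for p in train + valid if not os.path.exists(p)],
    }

# check_split("data/custom/train.txt", "data/custom/valid.txt")
```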

6. Set the training parameters

--model_def config/yolov3-custom.cfg
--data_config config/custom.data
--pretrained_weights weights/darknet53.conv.74
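
--data_config points at a small text file. A sketch that writes one, assuming the key=value layout of the repository's sample config/custom.data (classes, train, valid and names keys); compare it with the file in your own checkout before relying on it. The classes value should equal the <num-classes> passed to create_custom_model.sh and the number of lines in classes.names. The helper name is hypothetical:

```python
import os

def write_data_config(path, num_classes, root="data/custom"):
    """Write a custom.data file with classes/train/valid/names entries (assumed layout)."""
    lines = [
        f"classes={num_classes}",
        f"train={root}/train.txt",
        f"valid={root}/valid.txt",
        f"names={root}/classes.names",
    ]
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "w") as f:
        f.write("\n".join(lines) + "\n")

# write_data_config("config/custom.data", num_classes=3)
```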

python train.py --model_def config/yolov3-custom.cfg --data_config config/custom.data --pretrained_weights weights/darknet53.conv.74


To suppress warning messages:
python -W ignore train.py --model_def config/yolov3-custom.cfg --data_config config/custom.data --pretrained_weights weights/darknet53.conv.74
Reposted from blog.csdn.net/u014090429/article/details/102497445