用自己制作的数据集训练yolo

用自己制作的数据集训练yolo

最近在做自己数据集做移动端的目标检测,利用darkflow对yolo进行训练,这应该是第一次写博客,就当做给自己做个总结

本文是仿照voc数据集制作数据

利用labelimg制作数据集:

labelimg的github链接: link.

训练自己的数据集

数据集:包括image 和annotation
1.创建一个配置文件,复制 tiny-yolo-voc.cfg文件
——修改类别:将[region]层(最后一层)中的类更改为要训练的类的数量
——将[convolutional]层(第二层到最后一层)中的滤镜更改为num (classes + 5)。 在我们的例子中,num是5并且类是3,所以5 (3 + 5)= 40,因此过滤器被设置为40。
2.更改label.txt
3.load加载权重文件
4.修改annotation和image路径

在训练过程中如果中断了,想要在原本基础上训练,在darkflow/defaults.py 文件里面可以直接更改load为你训练保存在ckpt里的次数,注意默认设置在保存训练结果最邻近的数个模型(default.py里的keep设置可知)。
tip:建议将keep值改大些,我就是跑到不到200个epoch,loss值出现nan,然而只保存最后十个模型就凉了…

训练:
flow --model cfg/yolo-voc-2c.cfg --load bin/yolov2-tiny.weights --train --annotation /your_annotation_path --dataset /your_images_path --gpu 1.0 --savepb

测试:
flow --model cfg/v2/yolo-voc-2c.cfg --load 17000(在checkpoint可查看) --json

遇见的bug:
1.图片shape不匹配:
排除环境问题,是xml文件中的filename名字不对应
2.此次做的数据集只有500张一个类,训练200个epoch,测试时,图片没有检测框:将阈值设小,能出现检测框,但是置信度为0,可能显示的置信度只保留小数点后两位;

于是做了数据增强:图片加模糊,裁剪,总共2100张图片再训练,效果还行

这里附上自己写的数据增强的code,可同时修改xml文件

import numpy as np
import cv2
import xml.dom.minidom
import matplotlib.pyplot as plt
import shutil
import time
import random
import os
import math
import numpy as np
from skimage.util import random_noise
from skimage import exposure


def scale_transform(img_name,resize_to):a
    img_path = '/mnt/mdisk/pyq/data_aug/generate_dataset/pic/'+ img_name
    xml_path = '/mnt/mdisk/pyq_data_aug/generate_dataset/ann/'+ img_name
    xml_path = xml_path.replace('.jpg','.xml')



    img = cv2.imread(img_path)
    w = img.shape[1]
    h = img.shape[0]
    img = cv2.resize(img, (resize_to,resize_to))
    cv2.imwrite(img_path, img)
    tree = xml.dom.minidom.parse(xml_path)
    root = tree.documentElement
    box = root.getElementsByTagName('bndbox')
    size = root.getElementsByTagName('size')
    xmin = int(float(box[0].childNodes[1].firstChild.data))
    ymin = int(float(box[0].childNodes[3].firstChild.data))
    xmax = int(float(box[0].childNodes[5].firstChild.data))
    ymax = int(float(box[0].childNodes[7].firstChild.data))
    xmin_new = (resize_to / w) * xmin
    ymin_new = (resize_to / h) * ymin
    xmax_new = (resize_to / w) * xmax
    ymax_new = (resize_to / h) * ymax
    box[0].childNodes[1].firstChild.data = str(int(xmin_new))
    box[0].childNodes[3].firstChild.data = str(int(ymin_new))
    box[0].childNodes[5].firstChild.data = str(int(xmax_new))
    box[0].childNodes[7].firstChild.data = str(int(ymax_new))
    size[0].childNodes[1].firstChild.data = str(resize_to)  # new w
    size[0].childNodes[3].firstChild.data = str(resize_to)  # new h
    with open(xml_path, 'w') as file:
        tree.writexml(file)


def random_crop(img_name, crop_num):

    img_path = '/mnt/mdisk/pyq/data_aug/new_dataset/pic/' + img_name
    xml_path = '/mnt/mdisk/pyq/data_aug/new_dataset/ann/' + img_name
    img = cv2.imread(img_path + '.jpg')
    w = img.shape[1]
    h = img.shape[0]
    tree = xml.dom.minidom.parse(xml_path + '.xml')
    root = tree.documentElement
    box = root.getElementsByTagName('bndbox')
    filename = root.getElementsByTagName('filename')
    path = root.getElementsByTagName('path')
    size = root.getElementsByTagName('size')
    xmin = int(float(box[0].childNodes[1].firstChild.data))
    ymin = int(float(box[0].childNodes[3].firstChild.data))
    xmax = int(float(box[0].childNodes[5].firstChild.data))
    ymax = int(float(box[0].childNodes[7].firstChild.data))
    for i in range(crop_num):
        # pic & xml
        bottom = -1
        while (bottom < 0):
            # top = np.random.random_integers(ymin)
            # left = np.random.random_integers(xmin)
            # right = np.random.random_integers(w - xmax)
            # bottom = np.random.random_integers(h - ymax)
            top = np.random.randint(1, ymin+i)
            left = np.random.randint(1, xmin+i)
            right = np.random.randint(1,w-xmax +i)
          #  bottom = np.random.randint(1,h - ymax+i)
            bottom = np.random.random_integers( h -ymax)
            #bottom = int((left + right + xmax - xmin) * (4032 / 3024) - top - ymax + ymin)
            crop_xmin = max(0, xmin - left )
            crop_ymin = max(0, ymin - top)
            crop_xmax = min(0, xmax + right)
            crop_ymax = min(0, ymax + bottom)

        #img_new = img[ymin - top:ymax + bottom, xmin - left:xmax + right, :]
        img_new = img[crop_ymin :crop_ymax, crop_xmin:crop_xmax]
        cv2.imwrite(img_path.replace('new', 'generate') + '_c' + str(i + 1) + '.jpg', img_new)
        size[0].childNodes[1].firstChild.data = str(left + right + xmax - xmin)  # new w
        size[0].childNodes[3].firstChild.data = str(top + bottom + ymax - ymin)  # new h
        box[0].childNodes[1].firstChild.data = str(left)
        box[0].childNodes[3].firstChild.data = str(top)
        box[0].childNodes[5].firstChild.data = str(left + xmax - xmin)
        box[0].childNodes[7].firstChild.data = str(top + ymax - ymin)
        filename[0].firstChild.data = img_name + '_c' + str(i) + '.jpg'
        path[0].firstChild.data = img_path + '_c' + str(i) + '.jpg'
        shutil.copyfile(xml_path + '.xml', xml_path.replace('new', 'generate') + '_c' + str(i + 1) + '.xml')
        with open(xml_path.replace('new', 'generate') + '_c' + str(i + 1) + '.xml', 'w') as file:
            tree.writexml(file)

def img_blur(img_name):
    # pic
    img_path = '/mnt/mdisk/pyq/data_aug/new_dataset/pic/' + img_name
    xml_path = '/mnt/mdisk/pyq/data_aug/new_dataset/ann/' + img_name
    img = cv2.imread(img_path + '.jpg')
    i = 9
    while(i % 2 == 0):
        i = np.random.random_integers(10)
    img = cv2.GaussianBlur(img, ksize=(i,i), sigmaX=0, sigmaY=0)
    cv2.imwrite(img_path.replace('new', 'generate') + '_blur.jpg', img)

    # xml
    tree = xml.dom.minidom.parse(xml_path + '.xml')
    root = tree.documentElement
    filename = root.getElementsByTagName('filename')
    path = root.getElementsByTagName('path')
    filename[0].firstChild.data = filename[0].firstChild.data[:-4] + '_blur.jpg'
    path[0].firstChild.data = img_path + '_blur.jpg'
    with open(xml_path.replace('new', 'generate') +'_blur.xml', 'w') as file:
        tree.writexml(file)

# 调整亮度
#def _changeLight(self, img):
    # random.seed(int(time.time()))
  #  flag = random.uniform(0.5, 1.5) #flag>1为调暗,小于1为调亮
  #  return exposure.adjust_gamma(img, flag)


if __name__ == '__main__':

    #scale_transform('0356_c5', 416)
    # img_blur('0001')
    #random_crop('0001',1)
    # generate new data
    names = os.listdir('/mnt/mdisk/pyq/data_aug/new_dataset/ann')
    for i in names:
        name = i[:-4]
        random_crop(name,5)

        #img_blur(name)
        print('----')
    # file_list = os.listdir(source_file)
    # for n in file_list:
    #     save_path = target_file + n
    #     tree = xml.dom.mindom.parse(save_path)
    #     root = tree.documentElement
    #     filename = root.getElementsByTagName('filename')
    #     filename[0].firstChild.data = n.replace('.xml','.jpg')
    #     path = root.getElementsByTagName('path')
    #     path[0].firstChild.data = (save_path.replace('ann','pic')).replace('.xml','.jpg')
    #
    #     with open(save_path, 'w') as file:
    #         tree.writexml(file)


代码参考:(https://blog.csdn.net/weixin_41644725/article/details/85678348)

猜你喜欢

转载自blog.csdn.net/weixin_40589607/article/details/96765185
今日推荐