Training YOLOv6 on a YOLOv5-format dataset

Mostly I just didn't want to restructure my datasets: most of them are already in YOLOv5 format, and the idea was to point YOLOv6 at them as-is and train.

YOLOv6 dataset

File hierarchy of a YOLOv6 dataset:

custom_dataset
├── images
│   ├── train
│   │   ├── train0.jpg
│   │   └── train1.jpg
│   ├── val
│   │   ├── val0.jpg
│   │   └── val1.jpg
│   └── test
│       ├── test0.jpg
│       └── test1.jpg
└── labels
    ├── train
    │   ├── train0.txt
    │   └── train1.txt
    ├── val
    │   ├── val0.txt
    │   └── val1.txt
    └── test
        ├── test0.txt
        └── test1.txt

YOLOv6 dataset config file:

# Please ensure that your custom_dataset is placed in the same parent dir as YOLOv6_DIR
train: ../custom_dataset/images/train # train images
val: ../custom_dataset/images/val # val images
test: ../custom_dataset/images/test # test images (optional)

# whether this is the COCO dataset; only the COCO dataset should be set to True.
is_coco: False

# Classes
nc: 20  # number of classes
names: ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
        'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']  # class names
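If you edit nc or names for a custom dataset, a quick sanity check before training saves a confusing crash later. A minimal sketch, assuming the yaml above is saved as data/custom.yaml (hypothetical path) and PyYAML is installed:

import yaml

# Load the data yaml and verify nc matches the number of class names.
with open("data/custom.yaml") as f:
    d = yaml.safe_load(f)

assert len(d["names"]) == d["nc"], f"nc={d['nc']} but {len(d['names'])} names"
print("train:", d["train"], "| val:", d["val"])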

YOLOv5 dataset

File hierarchy of a YOLOv5 dataset:

custom_dataset
├── train
│   ├── images
│   │   ├── train0.jpg
│   │   └── train1.jpg
│   └── labels
│       ├── train0.txt
│       └── train1.txt
└── valid
    ├── images
    │   ├── val0.jpg
    │   └── val1.jpg
    └── labels
        ├── val0.txt
        └── val1.txt
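Note that only the directory layout differs between the two formats; the per-image label txt files are identical: one object per line as class_id x_center y_center width height, all normalized to [0, 1]. A hypothetical train0.txt with two objects:

14 0.512 0.433 0.210 0.620
0 0.207 0.653 0.118 0.380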


YOLOv5 dataset config file:

path: E:\data
train: train/images
val: valid/images
nc: 2
names:
- head
- person

With this layout, YOLOv6's stock loading code cannot find the labels folders: it derives the label directory as a sibling of the images directory (labels/<split> next to images/<split>), whereas YOLOv5 keeps labels inside each split, next to images.
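Concretely, here are the two lookups applied to the train split of the layout above (a standalone sketch of the osp.join expressions used in datasets.py; forward slashes so it runs the same on any OS):

import os.path as osp

img_dir = "E:/data/train/images"  # a YOLOv5-style train image dir

# YOLOv6's original lookup: <dataset root>/labels/<basename of img_dir>
v6_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir))
print(v6_dir)  # E:/data/labels/images -> does not exist in a YOLOv5 layout

# Where a YOLOv5 layout actually keeps the labels: inside the split, next to images
v5_dir = osp.join(osp.dirname(img_dir), "labels")
print(v5_dir)  # E:/data/train/labels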

Fix this by modifying the code in the following two .py files.

In YOLOv6\yolov6\core\engine.py, at line 48 (right after self.data_dict is loaded in Trainer.__init__), add:

from pathlib import Path

# Resolve the optional 'path' key from a YOLOv5-style data yaml and prepend it
# to the train/val/test entries (mirrors YOLOv5's check_dataset logic).
FILE = Path(__file__).resolve()
ROOT = FILE.parents[1]  # the YOLOv6/yolov6 package dir
path = Path(self.data_dict.get('path') or '')
if not path.is_absolute():
    path = (ROOT / path).resolve()
for k in 'train', 'val', 'test':
    if self.data_dict.get(k):  # prepend path
        self.data_dict[k] = str(path / self.data_dict[k]) if isinstance(self.data_dict[k], str) else [
            str(path / x) for x in self.data_dict[k]]
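To see what this buys you, here is the same prepend logic as a standalone sketch, fed with the values from the YOLOv5 yaml above (Path.cwd() stands in for the package dir; on Windows E:/data is absolute, so it is used as-is):

from pathlib import Path

data_dict = {"path": "E:/data", "train": "train/images", "val": "valid/images"}

ROOT = Path.cwd()  # stand-in for the yolov6 package dir used in engine.py
path = Path(data_dict.get("path") or "")
if not path.is_absolute():
    path = (ROOT / path).resolve()
for k in ("train", "val"):
    if data_dict.get(k):
        data_dict[k] = str(path / data_dict[k])

print(data_dict["train"])  # E:\data\train\images on Windows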
For context, the complete Trainer.__init__ after the change (the inserted block is marked with dashes):

class Trainer:
    def __init__(self, args, cfg, device):
        self.args = args
        self.cfg = cfg
        self.device = device

        if args.resume:
            self.ckpt = torch.load(args.resume, map_location='cpu')

        self.rank = args.rank
        self.local_rank = args.local_rank
        self.world_size = args.world_size
        self.main_process = self.rank in [-1, 0]
        self.save_dir = args.save_dir
        # get data loader
        self.data_dict = load_yaml(args.data_path)
        self.num_classes = self.data_dict['nc']
        # ----- added: prepend 'path' from the data yaml to train/val/test -----
        from pathlib import Path
        FILE = Path(__file__).resolve()
        ROOT = FILE.parents[1]
        path = Path(self.data_dict.get('path') or '')
        if not path.is_absolute():
            path = (ROOT / path).resolve()
        for k in 'train', 'val', 'test':
            if self.data_dict.get(k):  # prepend path
                self.data_dict[k] = str(path / self.data_dict[k]) if isinstance(self.data_dict[k], str) else [
                    str(path / x) for x in self.data_dict[k]]

        self.train_loader, self.val_loader = self.get_data_loader(args, cfg, self.data_dict)
        # get model and optimizer
        model = self.get_model(args, cfg, self.num_classes, device)
        if self.args.distill:
            self.teacher_model = self.get_teacher_model(args, cfg, self.num_classes, device)
        if self.args.quant:
            self.quant_setup(model, cfg, device)
        if cfg.training_mode == 'repopt':
            scales = self.load_scale_from_pretrained_models(cfg, device)
            reinit = False if cfg.model.pretrained is not None else True
            self.optimizer = RepVGGOptimizer(model, scales, args, cfg, reinit=reinit)
        else:
            self.optimizer = self.get_optimizer(args, cfg, model)
        self.scheduler, self.lf = self.get_lr_scheduler(args, cfg, self.optimizer)
        self.ema = ModelEMA(model) if self.main_process else None
        # tensorboard
        self.tblogger = SummaryWriter(self.save_dir) if self.main_process else None
        self.start_epoch = 0
        # resume
        if hasattr(self, "ckpt"):
            resume_state_dict = self.ckpt['model'].float().state_dict()  # checkpoint state_dict as FP32
            model.load_state_dict(resume_state_dict, strict=True)  # load
            self.start_epoch = self.ckpt['epoch'] + 1
            self.optimizer.load_state_dict(self.ckpt['optimizer'])
            if self.main_process:
                self.ema.ema.load_state_dict(self.ckpt['ema'].float().state_dict())
                self.ema.updates = self.ckpt['updates']
        self.model = self.parallel_model(args, model, device)
        self.model.nc, self.model.names = self.data_dict['nc'], self.data_dict['names']

        self.max_epoch = args.epochs
        self.max_stepnum = len(self.train_loader)
        self.batch_size = args.batch_size
        self.img_size = args.img_size
        self.vis_imgs_list = []
        self.write_trainbatch_tb = args.write_trainbatch_tb
        # set color for classnames
        self.color = [tuple(np.random.choice(range(256), size=3)) for _ in range(self.model.nc)]

        self.loss_num = 3
        self.loss_info = ['Epoch', 'iou_loss', 'dfl_loss', 'cls_loss']
        if self.args.distill:
            self.loss_num += 1
            self.loss_info += ['cwd_loss']

In YOLOv6\yolov6\data\datasets.py, in the get_imgs_labels function (line 268), change the label_dir lookup to:

try:
    # Native YOLOv6 layout: labels/<split> as a sibling of images/<split>
    label_dir = osp.join(
        osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
    )
    assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
except AssertionError:
    # Fall back to the YOLOv5 layout: labels inside the split, next to images
    label_dir = osp.join(
        osp.dirname(img_dir), "labels"
    )
    assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
For context, the complete get_imgs_labels function after the change (the modified block is marked with dashes):

    def get_imgs_labels(self, img_dir):

        assert osp.exists(img_dir), f"{img_dir} is an invalid directory path!"
        valid_img_record = osp.join(
            osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
        )
        NUM_THREADS = min(8, os.cpu_count())

        img_paths = glob.glob(osp.join(img_dir, "**/*"), recursive=True)
        img_paths = sorted(
            p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS and os.path.isfile(p)
        )
        assert img_paths, f"No images found in {img_dir}."

        img_hash = self.get_hash(img_paths)
        if osp.exists(valid_img_record):
            with open(valid_img_record, "r") as f:
                cache_info = json.load(f)
                if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
                    img_info = cache_info["information"]
                else:
                    self.check_images = True
        else:
            self.check_images = True

        # check images
        if self.check_images and self.main_process:
            img_info = {}
            nc, msgs = 0, []  # number corrupt, messages
            LOGGER.info(
                f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = tqdm(
                    pool.imap(TrainValDataset.check_image, img_paths),
                    total=len(img_paths),
                )
                for img_path, shape_per_img, nc_per_img, msg in pbar:
                    if nc_per_img == 0:  # not corrupted
                        img_info[img_path] = {"shape": shape_per_img}
                    nc += nc_per_img
                    if msg:
                        msgs.append(msg)
                    pbar.desc = f"{nc} image(s) corrupted"
            pbar.close()
            if msgs:
                LOGGER.info("\n".join(msgs))

            cache_info = {"information": img_info, "image_hash": img_hash}
            # save valid image paths.
            with open(valid_img_record, "w") as f:
                json.dump(cache_info, f)

        # check and load anns
        # ----- modified: resolve label_dir for both YOLOv6 and YOLOv5 layouts -----
        try:
            # Native YOLOv6 layout: labels/<split> as a sibling of images/<split>
            label_dir = osp.join(
                osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
            )
            assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
        except AssertionError:
            # Fall back to the YOLOv5 layout: labels inside the split, next to images
            label_dir = osp.join(
                osp.dirname(img_dir), "labels"
            )
            assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"
        # Look for labels in the same relative dir that the images are in
        def _new_rel_path_with_ext(base_path: str, full_path: str, new_ext: str):
            rel_path = osp.relpath(full_path, base_path)
            return osp.join(osp.dirname(rel_path), osp.splitext(osp.basename(rel_path))[0] + new_ext)


        img_paths = list(img_info.keys())
        label_paths = sorted(
            osp.join(label_dir, _new_rel_path_with_ext(img_dir, p, ".txt"))
            for p in img_paths
        )
        assert label_paths, f"No labels found in {label_dir}."
        label_hash = self.get_hash(label_paths)
        if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
            self.check_labels = True

        if self.check_labels:
            cache_info["label_hash"] = label_hash
            nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number corrupt, messages
            LOGGER.info(
                f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = pool.imap(
                    TrainValDataset.check_label_files, zip(img_paths, label_paths)
                )
                pbar = tqdm(pbar, total=len(label_paths)) if self.main_process else pbar
                for (
                    img_path,
                    labels_per_file,
                    nc_per_file,
                    nm_per_file,
                    nf_per_file,
                    ne_per_file,
                    msg,
                ) in pbar:
                    if nc_per_file == 0:
                        img_info[img_path]["labels"] = labels_per_file
                    else:
                        img_info.pop(img_path)
                    nc += nc_per_file
                    nm += nm_per_file
                    nf += nf_per_file
                    ne += ne_per_file
                    if msg:
                        msgs.append(msg)
                    if self.main_process:
                        pbar.desc = f"{nf} label(s) found, {nm} label(s) missing, {ne} label(s) empty, {nc} invalid label files"
            if self.main_process:
                pbar.close()
                with open(valid_img_record, "w") as f:
                    json.dump(cache_info, f)
            if msgs:
                LOGGER.info("\n".join(msgs))
            if nf == 0:
                LOGGER.warning(
                    f"WARNING: No labels found in {osp.dirname(img_paths[0])}. "
                )

        if self.task.lower() == "val":
            if self.data_dict.get("is_coco", False): # use original json file when evaluating on coco dataset.
                assert osp.exists(self.data_dict["anno_path"]), "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
            else:
                assert (
                    self.class_names
                ), "Class names is required when converting labels to coco format for evaluating."
                save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
                if not osp.exists(save_dir):
                    os.mkdir(save_dir)
                save_path = osp.join(
                    save_dir, "instances_" + osp.basename(img_dir) + ".json"
                )
                TrainValDataset.generate_coco_format_labels(
                    img_info, self.class_names, save_path
                )

        img_paths, labels = list(
            zip(
                *[
                    (
                        img_path,
                        np.array(info["labels"], dtype=np.float32)
                        if info["labels"]
                        else np.zeros((0, 5), dtype=np.float32),
                    )
                    for img_path, info in img_info.items()
                ]
            )
        )
        self.img_info = img_info
        LOGGER.info(
            f"{self.task}: Final numbers of valid images: {len(img_paths)}/ labels: {len(labels)}. "
        )
        return img_paths, labels

With these two changes, YOLOv6 can train directly on a YOLOv5-format dataset.
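Training then launches the same way as with a native YOLOv6 dataset, e.g. something like the following (flag names taken from YOLOv6's tools/train.py argparse; double-check against your version of the repo):

python tools/train.py --data-path data/custom.yaml --conf-file configs/yolov6s.py --batch-size 32 --epochs 100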

Reposted from blog.csdn.net/athrunsunny/article/details/126998742