Scaled-Yolov4训练代码一步步复现

Scaled-Yolov4训练代码一步步复现

train.py

from torch.nn.parallel import DistributedDataParallel as DDP
import yaml
import math
import os
import argparse
from pathlib import Path
import numpy as np
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
from utils.torch_utils import init_seeds, intersect_dicts, ModelEMA
from models.yolo import Model
from utils.general import check_img_size, torch_distributed_zero_first
from utils.google_utils import attempt_download
from utils.datasets import create_dataloader


def train(hyp, opt, device, tb_writer=None):
    """Prepare a training run: output dirs, model, optimizer, scheduler, wrappers, train loader.

    This function is a work in progress: it performs the setup steps only and
    currently returns ``None`` after the training dataloader has been built.

    Args:
        hyp: Hyper-parameter dict. Reads ``lr0``, ``momentum`` and
            ``weight_decay``; ``weight_decay`` is rescaled in place.
        opt: Parsed options namespace. Reads ``logdir``, ``epochs``,
            ``batch_size``, ``total_batch_size``, ``weights``, ``global_rank``,
            ``local_rank``, ``world_size``, ``data``, ``cfg``, ``single_cls``,
            ``adam``, ``sync_bn``, ``img_size``, ``cache_images`` and ``rect``.
        device: ``torch.device`` the model is placed on.
        tb_writer: Optional writer object exposing ``log_dir``; when given, its
            directory is used instead of ``<opt.logdir>/evolve``.

    Raises:
        AssertionError: If the number of class names differs from ``nc``.
    """
    print(f"Hyparameters {hyp}")
    log_dir = Path(tb_writer.log_dir) if tb_writer else Path(opt.logdir) / "evolve"
    wdir = str(log_dir / "weights") + os.sep
    os.makedirs(wdir, exist_ok=True)
    last = wdir + "last.pt"
    best = wdir + "best.pt"
    results_file = str(log_dir / "results.txt")
    epochs, batch_size, total_batch_size, weights, rank = \
        opt.epochs, opt.batch_size, opt.total_batch_size, opt.weights, opt.global_rank
    # TODO : Use DDP logging. Only the first process is allowed to log.

    # Save run settings (sort_keys=False keeps the original key order).
    with open(log_dir / "hyp.yaml", 'w') as f:
        yaml.dump(hyp, f, sort_keys=False)
    with open(log_dir / "opt.yaml", 'w') as f:
        yaml.dump(vars(opt), f, sort_keys=False)

    # Configure
    cuda = device.type != "cpu"
    init_seeds(2 + rank)
    # NOTE(review): yaml.FullLoader is not meant for untrusted files; only
    # point --data at dataset descriptions you control.
    with open(opt.data, 'r') as f:
        data_dict = yaml.load(f, Loader=yaml.FullLoader)
    train_path = data_dict["train"]
    test_path = data_dict["test"]
    nc, names = (1, ["item"]) if opt.single_cls else (int(data_dict["nc"]),
                                                      data_dict["names"])
    assert len(names) == nc, \
        f"{len(names)} names found for nc={nc} datasets in {opt.data}"

    # Model: start from a checkpoint when `weights` names a .pt file.
    pretrained = weights.endswith(".pt")
    if pretrained:
        with torch_distributed_zero_first(rank):
            attempt_download(weights)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(weights, map_location=device)
        # Place the model on `device` here as well, matching the non-pretrained
        # branch; otherwise the pretrained model stays on the CPU.
        model = Model(opt.cfg or checkpoint["model"].yaml, nc=nc, ch=3).to(device)
        exclude = ["anchor"] if opt.cfg else []
        state_dict = checkpoint["model"].float().state_dict()
        # Keep only entries whose name and shape match the freshly built model.
        state_dict = intersect_dicts(state_dict, model.state_dict(), exclude)
        model.load_state_dict(state_dict, strict=False)
        print(f"Transferred {len(state_dict)}/{len(model.state_dict())} item from {weights}")
    else:
        model = Model(opt.cfg, nc=nc, ch=3).to(device)

    # Optimizer
    nbs = 64  # nominal batch size
    # Number of batches to accumulate so the effective batch is close to nbs.
    accumulate = max(round(nbs / total_batch_size), 1)
    hyp["weight_decay"] *= total_batch_size * accumulate / nbs
    # Parameter groups: pg0 = everything else, pg1 = weights that receive
    # weight decay, pg2 = biases.
    pg0, pg1, pg2 = [], [], []
    for k, v in model.named_parameters():
        v.requires_grad = True
        if ".bias" in k:
            pg2.append(v)
        elif ".weight" in k and ".bn" not in k:
            # Parameter names end in ".weight" (singular); the previous
            # ".weights" test never matched, so no parameter got weight decay.
            pg1.append(v)
        else:
            pg0.append(v)
    if opt.adam:
        optimizer = optim.Adam(pg0, lr=hyp["lr0"], betas=(hyp["momentum"], 0.999))
    else:
        optimizer = optim.SGD(pg0, lr=hyp["lr0"], momentum=hyp["momentum"], nesterov=True)
    optimizer.add_param_group({"params": pg1, "weight_decay": hyp["weight_decay"]})
    optimizer.add_param_group({"params": pg2})
    # pg2 holds the biases and pg1 the decayed weights; report them under the right labels.
    print(f"Optimizer groups:{len(pg2)}.bias,{len(pg1)}conv_weights,{len(pg0)}.others")
    del pg0, pg1, pg2

    # Learning-rate schedule: cosine factor going from 1.0 at x=0 to 0.2 at x=epochs.
    lf = lambda x: (((1 + math.cos(x * math.pi / epochs)) / 2) ** 1.0) * 0.8 + 0.2
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    # Resume
    start_epoch, best_fitness = 0, 0.0
    if pretrained:
        # Optimizer
        if checkpoint["optimizer"] is not None:
            optimizer.load_state_dict(checkpoint["optimizer"])
            best_fitness = checkpoint["best_fitness"]
        # Results
        if checkpoint.get("training_results") is not None:
            with open(results_file, 'w') as f:
                f.write(checkpoint["training_results"])
        # Epoch
        start_epoch = checkpoint["epoch"] + 1
        if epochs < start_epoch:
            print('%s has been trained for %g epochs. Fine-tuning for %g additional epochs.' %
                  (weights, checkpoint['epoch'], epochs))
            epochs += checkpoint["epoch"]
        # These names only exist on the pretrained path; deleting them
        # unconditionally raised NameError when training from scratch.
        del checkpoint, state_dict

    # Image sizes, rounded up to a multiple of the largest model stride.
    gs = int(max(model.stride))
    imgsz, imgsz_test = [check_img_size(x, gs) for x in opt.img_size]

    # DP model (single process, several GPUs)
    if cuda and rank == -1 and torch.cuda.device_count() > 1:
        model = torch.nn.parallel.DataParallel(model)

    # SyncBatchNorm for multi-process training
    if opt.sync_bn and cuda and rank != -1:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
        print("Using SyncBatchNorm")

    # Exponential moving average (main process only)
    ema = ModelEMA(model) if rank in [-1, 0] else None

    # DDP model; output_device is a single device index, not a list.
    if cuda and rank != -1:
        model = DDP(model, device_ids=[opt.local_rank], output_device=opt.local_rank)

    # TrainLoader
    dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt,
                                            cache=opt.cache_images, rect=opt.rect,
                                            local_rank=rank, world_size=opt.world_size)

    # Largest value found in column 0 of the concatenated label arrays.
    mlc = np.concatenate(dataset.labels, 0)[:, 0].max()
    # TestLoader


if __name__ == "__main__":
    # Command-line entry point: parse options, set up distributed state,
    # then load the hyper-parameter file.
    parse = argparse.ArgumentParser()
    parse.add_argument("--weights", type=str, default="yolov4-p5.pt")
    parse.add_argument("--cfg", type=str, default="")
    parse.add_argument("--data", type=str, default="data/coco128.yaml")
    parse.add_argument("--hyp", type=str, default="")
    parse.add_argument("--epochs", type=int, default=300)
    parse.add_argument("--batch-size", type=int, default=16)
    parse.add_argument("--img-size", nargs='+', type=int, default=[640, 640])
    # train() reads opt.rect, so the flag has to exist on the namespace.
    parse.add_argument("--rect", action="store_true")
    parse.add_argument("--resume", nargs='?', const='get_last', default=False)
    parse.add_argument("--nosave", action="store_true")
    parse.add_argument("--notest", action="store_true")
    parse.add_argument("--noautoanchor", action="store_true")
    parse.add_argument("--evolve", action="store_true")
    parse.add_argument("--bucket", type=str, default="")
    parse.add_argument("--cache-images", action="store_true")
    parse.add_argument("--name", default="")
    parse.add_argument("--device", default="")
    parse.add_argument("--multi-scale", action="store_true")
    parse.add_argument("--single-cls", action="store_true")
    parse.add_argument("--adam", action="store_true")
    parse.add_argument("--sync-bn", action="store_true")
    # Accept both spellings; the destination stays opt.local_rank.
    parse.add_argument("--local-rank", "--local_rank", type=int, default=-1)
    parse.add_argument("--logdir", type=str, default="runs/")
    opt = parse.parse_args()

    # Values train() reads from opt. They must exist for single-process runs
    # too, and total_batch_size must be captured before batch_size is split
    # per process below.
    opt.total_batch_size = opt.batch_size
    opt.world_size = int(os.environ.get("WORLD_SIZE", 1))
    opt.global_rank = int(os.environ.get("RANK", -1))

    # DDP mode
    if opt.local_rank != -1:
        assert torch.cuda.device_count() > opt.local_rank
        torch.cuda.set_device(opt.local_rank)
        device = torch.device("cuda", opt.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://")
        # Number of processes in the default process group.
        opt.world_size = dist.get_world_size()
        # Rank of this process within the default process group.
        opt.global_rank = dist.get_rank()
        assert opt.batch_size % opt.world_size == 0, "--batch-size must be multiple of CUDA device count"
        opt.batch_size = opt.total_batch_size // opt.world_size
    print(opt)

    # Fail with a usage message instead of an obscure open("") error.
    if not opt.hyp:
        parse.error("--hyp is required (path to a hyper-parameter YAML file)")
    with open(opt.hyp) as f:
        # NOTE(review): yaml.FullLoader is not meant for untrusted files.
        hyp = yaml.load(f, Loader=yaml.FullLoader)

torch_utils.py

import math
import os
import time
from copy import deepcopy
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.functional as F
import torchvision.models as models


def init_seeds(seed=0):
    """Seed torch's RNG and pick the matching cuDNN mode.

    A seed of 0 selects deterministic cuDNN with benchmarking off; any other
    seed turns benchmarking on and determinism off, trading exact
    repeatability for speed.
    """
    torch.manual_seed(seed)

    reproducible = bool(seed == 0)
    cudnn.deterministic = reproducible
    cudnn.benchmark = not reproducible


def intersect_dicts(da, db, exclude=()):
    """Return the entries of ``da`` that can be carried over to ``db``.

    An entry is kept when its key also exists in ``db``, no string from
    ``exclude`` occurs inside the key, and both values have the same
    ``.shape``. Values are taken from ``da``.
    """
    matched = {}
    for key, value in da.items():
        if key not in db:
            continue
        if any(pattern in key for pattern in exclude):
            continue
        if value.shape == db[key].shape:
            matched[key] = value
    return matched


def is_parallel(model):
    """Return True when ``model`` is exactly a DataParallel or DistributedDataParallel wrapper."""
    wrappers = (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
    model_type = type(model)
    return any(model_type is wrapper for wrapper in wrappers)


def copy_attr(a, b, include=(), exclude=()):
    """Copy instance attributes from ``b`` onto ``a``.

    Names starting with an underscore and names listed in ``exclude`` are
    never copied. When ``include`` is non-empty, only the names it lists are
    considered.
    """
    for name, value in b.__dict__.items():
        selected = not len(include) or name in include
        if selected and not name.startswith('_') and name not in exclude:
            setattr(a, name, value)


class ModelEMA:
    """Exponential moving average of a model's state_dict.

    Adapted from https://github.com/rwightman/pytorch-image-models; see also
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    for the idea. The average covers every floating-point entry of the
    state_dict, i.e. parameters and buffers alike.

    The wrapped copy is taken at construction time, so where the instance is
    created relative to model init, device placement and the parallel
    wrappers matters.
    """

    def __init__(self, model, decay=0.9999, updates=0):
        """Snapshot ``model`` (unwrapping a parallel wrapper) as the EMA copy."""
        base_model = model.module if is_parallel(model) else model
        # The averaged copy stays in eval mode with the source's dtype (no half()).
        self.ema = deepcopy(base_model).eval()
        # Number of EMA updates applied so far.
        self.updates = updates
        # Effective decay ramps up from 0 towards `decay` as updates accumulate,
        # so early updates follow the live weights more closely.
        self.decay = lambda x: decay * (1 - math.exp(-x / 2000))
        for param in self.ema.parameters():
            param.requires_grad_(False)

    def update(self, model):
        """Blend the current weights of ``model`` into the averaged copy, in place."""
        with torch.no_grad():
            self.updates += 1
            ratio = self.decay(self.updates)

            source = model.module if is_parallel(model) else model
            live_state = source.state_dict()
            for name, averaged in self.ema.state_dict().items():
                # Non floating-point entries are left untouched.
                if not averaged.dtype.is_floating_point:
                    continue
                averaged *= ratio
                averaged += (1. - ratio) * live_state[name].detach()

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        """Copy plain attributes from ``model`` onto the averaged copy via ``copy_attr``."""
        copy_attr(self.ema, model, include, exclude)

general.py

import math
from contextlib import contextmanager
import torch


@contextmanager
def torch_distributed_zero_first(local_rank: int):
    """Context manager that lets the main process run the body before the others.

    Processes whose ``local_rank`` is neither -1 nor 0 wait at a barrier
    before entering the body. The process with ``local_rank == 0`` runs the
    body first and then joins the barrier, which releases the waiting
    processes. With ``local_rank == -1`` (non-distributed) no barrier is used.

    Args:
        local_rank: Rank of the calling process; -1 means non-distributed.
    """
    if local_rank not in [-1, 0]:
        # torch.distributed is a module, so barrier() is called on it directly.
        torch.distributed.barrier()
    yield
    if local_rank == 0:
        torch.distributed.barrier()


def check_img_size(imgsz, s=32):
    """Round ``imgsz`` up to a multiple of stride ``s``, warning when it changes.

    Returns the adjusted size; the input is returned unchanged when it is
    already a multiple of the stride.
    """
    stride = int(s)
    adjusted = make_divisible(imgsz, stride)
    size_changed = adjusted != imgsz
    if size_changed:
        print("WARNING!")
    return adjusted


def make_divisible(x, s):
    """Return the smallest multiple of ``s`` that is greater than or equal to ``x``."""
    multiples_needed = math.ceil(x / s)
    return multiples_needed * s

datasets.py

import glob
import os
from pathlib import Path

import numpy as np
import torch
from PIL import Image, ExifTags
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from utils.general import torch_distributed_zero_first

# URL appended to dataset-loading error messages (empty here).
help_url = ""
# File extensions treated as videos / images; compared against lower-cased suffixes.
vid_formats = ['.mov', '.avi', '.mp4', '.mpg', '.mpeg', '.m4v', '.wmv', '.mkv']
img_formats = ['.bmp', '.jpg', '.jpeg', '.png', '.tif', '.tiff', '.dng']

# Find the numeric EXIF tag id whose name is 'Orientation'. The loop variable
# is intentionally left bound at module level after the break: exit_size()
# reads the global `orientation` to look up an image's rotation.
for orientation in ExifTags.TAGS.keys():
    if ExifTags.TAGS[orientation] == 'Orientation':
        break


def get_hash(files):
    """Return a cheap fingerprint of ``files``: the summed size of those that exist.

    Paths that are not regular files contribute nothing.
    """
    total_bytes = 0
    for file_path in files:
        if os.path.isfile(file_path):
            total_bytes += os.path.getsize(file_path)
    return total_bytes


def exit_size(img):
    """Return the image size with EXIF rotation taken into account.

    Args:
        img: Image object exposing ``size`` and, optionally, ``_getexif()``.

    Returns:
        ``img.size`` with its two entries swapped when the EXIF orientation
        value is 6 or 8, otherwise ``img.size`` unchanged. Images without
        usable EXIF data also return ``img.size`` unchanged.
    """
    s = img.size
    try:
        rotation = dict(img._getexif().items())[orientation]
        # Orientation 6 (270 degrees) and 8 (90 degrees) both swap the two dimensions.
        if rotation in (6, 8):
            s = (s[1], s[0])
    except Exception:
        # Best effort: missing or unreadable EXIF data means "no rotation".
        # Only Exception is caught so KeyboardInterrupt/SystemExit still propagate.
        pass
    return s


def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None,
                      augment=False,
                      cache=False, pad=0.0, rect=False, local_rank=-1, world_size=1):
    """Build the dataset and a DataLoader over it.

    The whole construction runs inside ``torch_distributed_zero_first`` so the
    main process goes first. A ``DistributedSampler`` is attached whenever
    ``local_rank`` is not -1.

    NOTE(review): the loader uses ``LoadImagesAndLabels.collate_fn``; the class
    as visible in this file does not define it -- confirm it is provided.

    Returns:
        Tuple ``(dataloader, dataset)``.
    """
    distributed = local_rank != -1
    with torch_distributed_zero_first(local_rank):
        dataset = LoadImagesAndLabels(
            path, imgsz, batch_size,
            augment=augment,      # augment images
            hyp=hyp,              # augmentation hyper-parameters
            rect=rect,            # rectangular training
            cache_images=cache,
            single_cls=opt.single_cls,
            stride=int(stride),
            pad=pad,
        )
        # Never ask for a batch larger than the dataset itself.
        batch_size = min(batch_size, len(dataset))
        # Worker count: capped by CPUs per process, by the batch size
        # (0 workers for a batch of one) and by 8.
        worker_caps = [os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8]
        nw = min(worker_caps)
        train_sampler = None
        if distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        loader_kwargs = {
            "batch_size": batch_size,
            "num_workers": nw,
            "sampler": train_sampler,
            "pin_memory": True,
            "collate_fn": LoadImagesAndLabels.collate_fn,
        }
        dataloader = DataLoader(dataset, **loader_kwargs)
    return dataloader, dataset


class LoadImagesAndLabels(Dataset):
    """Dataset that collects image files and their label files, with an on-disk label cache.

    Work in progress: ``__getitem__`` is not implemented yet.
    """

    def __init__(self, path, img_size=640, batch_size=16, augment=False,
                 hyp=False, rect=False, image_weights=False,
                 cache_images=False, single_cls=False, stride=32, pad=0.0):
        """Collect image paths, derive label paths and load (or build) the label cache.

        Args:
            path: A directory, a text file listing image paths, or a list of either.
            img_size: Stored on the instance; also used for ``mosaic_border``.
            batch_size: Used to assign a batch index to every image.
            augment, hyp, image_weights, stride: Stored on the instance.
            rect: Stored as ``self.rect`` unless ``image_weights`` is set.
            cache_images, single_cls, pad: Accepted but not used yet.

        Raises:
            Exception: If a path does not exist or cannot be read.
            AssertionError: If no image with a supported extension is found.
        """
        try:
            f = []  # image files
            for p in path if isinstance(path, list) else [path]:
                p = str(Path(p))
                parent = str(Path(p).parent) + os.sep
                if os.path.isfile(p):
                    # Text file with one image path per line; a leading "./"
                    # is replaced by the list file's own directory.
                    with open(p, 'r') as t:
                        lines = t.read().splitlines()
                    f += [x.replace("./", parent) if x.startswith("./") else x for x in lines]
                elif os.path.isdir(p):
                    f += glob.iglob(p + os.sep + "*.*")
                else:
                    raise Exception(f"{p} does not exist")
            self.img_files = sorted(
                [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats]
            )
        except Exception as e:
            raise Exception("Error loading data from %s: %s\nSee %s" % (path, e, help_url)) from e
        n = len(self.img_files)
        assert n > 0, "No images found in %s.See %s" % (path, help_url)
        # bi[i] is the batch index of image i, e.g. with batch_size=4 the first
        # four images get index 0, the next four index 1, and so on.
        # The builtin int replaces np.int, which was removed in NumPy 1.24.
        bi = np.floor(np.arange(n) / batch_size).astype(int)  # batch index
        nb = bi[-1] + 1  # number of batches (not used yet)
        self.n = n  # number of images
        self.batch = bi
        self.img_size = img_size
        self.augment = augment
        self.hyp = hyp
        self.image_weights = image_weights
        self.rect = False if image_weights else rect
        self.mosaic = self.augment and not self.rect
        self.mosaic_border = [-img_size // 2, -img_size // 2]
        self.stride = stride
        # Label path = image path with "images" -> "labels" and the extension -> ".txt".
        self.label_files = [x.replace("images", "labels")
                             .replace(os.path.splitext(x)[-1], ".txt")
                            for x in self.img_files]
        # Reuse the cache next to the labels directory when its hash still matches.
        cache_path = str(Path(self.label_files[0]).parent) + ".cache"
        if os.path.isfile(cache_path):
            # NOTE(review): torch.load unpickles the file; only use caches you created.
            cache = torch.load(cache_path)
            if cache["hash"] != get_hash(self.label_files + self.img_files):
                cache = self.cache_labels(cache_path)
        else:
            cache = self.cache_labels(cache_path)
        # NOTE(review): cache_labels() stores None for images it failed to read;
        # such an entry makes the unpacking below fail.
        labels, shapes = zip(*[cache[x] for x in self.img_files])
        self.shapes = np.array(shapes, dtype=np.float64)
        self.labels = list(labels)

    def cache_labels(self, path="labels.cache"):
        """Scan every image/label pair, save the result to ``path`` and return it.

        Returns:
            Dict mapping each image path to ``[labels, shape]`` (or ``None``
            when the image could not be processed), plus a ``"hash"`` entry.
        """
        x = {}
        pbar = tqdm(zip(self.img_files, self.label_files), desc="Scanning images", total=len(self.img_files))
        for (img, label) in pbar:
            try:
                l = []
                image = Image.open(img)
                image.verify()
                shape = exit_size(image)
                assert (shape[0] > 9) & (shape[1] > 9), "image size <10 pixels"
                if os.path.isfile(label):
                    with open(label, 'r') as f:
                        l = np.array([x.split() for x in f.read().splitlines()], dtype=np.float32)
                if len(l) == 0:
                    # No label file or an empty one: zero rows, five columns.
                    l = np.zeros((0, 5), dtype=np.float32)
                x[img] = [l, shape]
            except Exception as e:
                # Best effort: record the failure and keep scanning.
                x[img] = None
                print(f"WARNING: {img}: {e}")
        x["hash"] = get_hash(self.label_files + self.img_files)
        torch.save(x, path)
        return x

    def __getitem__(self, item):
        # Not implemented yet.
        pass

    def __len__(self):
        """Return the number of image files."""
        return len(self.img_files)

猜你喜欢

转载自blog.csdn.net/qq_35140742/article/details/120528096