1 Initialize the process group
import os
from torch import distributed
try:
    # Read the process topology from the environment and join the group.
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes (global)
    rank = int(os.environ["RANK"])  # index of this process (global)
    local_rank = int(os.environ["LOCAL_RANK"])  # index of this process on its own machine (local)
    distributed.init_process_group("nccl")  # initialize the process group with the NCCL backend
except KeyError:
    # One of the variables above is missing: fall back to a single-process
    # group that rendezvous over a fixed local TCP address.
    world_size, rank, local_rank = 1, 0, 0
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )
2 Use DistributedSampler to divide the dataset
Unlike nn.DataParallel, the batch_size in distributed training is the number of input samples per GPU (i.e. per process), because each process only loads the partition of the data assigned to its rank. The total batch size is this batch_size multiplied by the number of processes. For example, if you train on 8 GPUs and the batch_size with nn.DataParallel is 3200, the equivalent batch_size with nn.DistributedDataParallel is 400:
from dataloader.distributed_sampler import DistributedSampler
# Shard the dataset across processes: the sampler is told how many replicas
# exist (world_size) and which one this process is (rank).
train_sampler = DistributedSampler(
    train_set,
    num_replicas=world_size,
    rank=rank,
    shuffle=True,
    seed=seed,
)
# Named ``train_loader`` to match the training loop in section 4.
train_loader = DataLoader(
    dataset=train_set,
    # pin_memory=True makes the DataLoader place the produced tensors in
    # page-locked (pinned) host memory, which speeds up the later copy from
    # host memory to GPU memory. (The Python literal is ``True``; the
    # lowercase ``true`` raises NameError.)
    pin_memory=True,
    batch_size=batch_size,  # per-process batch size, not the global one
    num_workers=num_workers,
    sampler=train_sampler,
)
3 Use DistributedDataParallel to encapsulate the model
DistributedDataParallel performs an all-reduce on the gradients computed on the different GPUs (that is, it combines the gradients calculated by each GPU and synchronizes the result). After the all-reduce, every GPU holds the same gradients: the mean of the per-GPU gradients from before the all-reduce:
# Build the backbone network and move it to the GPU.
backbone = get_model(
    cfg.network,
    dropout=0.0,
    fp16=cfg.fp16,
    num_features=cfg.embedding_size,
)
backbone = backbone.cuda()
# Wrap the model so its gradients are synchronized across processes; each
# process drives the single GPU identified by its local_rank.
backbone = torch.nn.parallel.DistributedDataParallel(
    module=backbone,
    broadcast_buffers=False,
    device_ids=[local_rank],
    bucket_cap_mb=16,
    find_unused_parameters=True,
)
4 Train the model
Load the input images, labels and model onto the GPU used by the current process:
for epoch in range(start_epoch, cfg.num_epoch):
    if isinstance(train_loader, DataLoader):
        # Tell the loader's sampler which epoch is starting. DistributedSampler
        # needs this value to keep the random seed consistent across processes.
        train_loader.sampler.set_epoch(epoch)

    for img, local_labels in train_loader:
        global_step += 1

        # Forward pass and loss computation.
        local_embeddings = backbone(img)
        loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)

        # Backward pass, then clip the backbone gradient norm to 5 before stepping.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
        opt.step()

        # Reset gradients and advance the learning-rate schedule once per step.
        opt.zero_grad()
        lr_scheduler.step()
5 Calculate the loss
distributed.all_gather(tensor_list, input_tensor): gathers the given input_tensor from all devices and places the results in tensor_list on every device:
from torch import distributed
# Collect every process's embeddings and labels into the *_gather_* lists.
distributed.all_gather(_gather_embeddings, local_embeddings)
distributed.all_gather(_gather_labels, local_labels)
# Sum the per-process loss values in place across all processes.
distributed.all_reduce(loss, op=distributed.ReduceOp.SUM)
6 Save the model
# Only the process with global rank 0 writes the weights file.
if rank == 0:
    path_module = os.path.join(cfg.output, "model_final.pt")
    # ``backbone.module`` is the model wrapped by DistributedDataParallel;
    # its state dict is what gets saved.
    final_weights = backbone.module.state_dict()
    torch.save(final_weights, path_module)
7 Start parallel program
(1) use torch.distributed.launch
This command runs the script n times in parallel (n is the number of GPUs used):
# --use_env stops the launcher from appending --local_rank=<n> to the script
# arguments: train.py only accepts the config path and reads LOCAL_RANK from
# the environment. On recent PyTorch releases the equivalent command is:
#   torchrun --nproc_per_node=8 train.py configs/ms1mv3_r50
python -m torch.distributed.launch --use_env --nproc_per_node=8 train.py configs/ms1mv3_r50
(2) use torch.multiprocessing
torch.multiprocessing.spawn creates the worker processes for you, avoiding some minor problems that torch.distributed.launch has when starting and exiting processes:
def main(rank, *args):
    """Per-process entry point run by ``torch.multiprocessing.spawn``.

    Args:
        rank: Index of the spawned process; ``spawn`` passes it as the first
            positional argument.
        *args: The extra arguments forwarded from ``spawn(..., args=...)``.
    """
    # Training code for one process goes here.


# The signature is spawn(fn, args=(), nprocs=1, ...): ``args`` comes before
# ``nprocs``, so pass both by keyword to avoid swapping them.
torch.multiprocessing.spawn(main, args=args, nprocs=nprocs)
8 code example
Refer to the insightface code:
import argparse
import logging
import os
from datetime import datetime
import numpy as np
import torch
from backbones import get_model
from dataset import get_dataloader
from losses import CombinedMarginLoss
from lr_scheduler import PolyScheduler
from partial_fc import PartialFC, PartialFCAdamW
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from utils.utils_callbacks import CallBackLogging, CallBackVerification
from utils.utils_config import get_config
from utils.utils_distributed_sampler import setup_seed
from utils.utils_logging import AverageMeter, init_logging
# Minimum supported torch release as a (major, minor) tuple.
MIN_TORCH_VERSION = (1, 12)
# Compare versions numerically. A plain string comparison is lexicographic,
# so e.g. "1.9.0" >= "1.12.0" evaluates to True and lets old releases through.
# Any local build suffix ("+cu113") is dropped before parsing.
installed_torch_version = tuple(
    int(part) for part in torch.__version__.split("+")[0].split(".")[:2]
)
# Raise explicitly instead of using ``assert`` so the check still runs when
# Python is started with -O.
if installed_torch_version < MIN_TORCH_VERSION:
    raise RuntimeError(
        "In order to enjoy the features of the new torch, "
        "we have upgraded the torch to 1.12.0. torch before than 1.12.0 may not work in the future."
    )
try:
    # Read this process's place in the job from the environment and join the
    # NCCL process group.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    distributed.init_process_group("nccl")
except KeyError:
    # One of the variables above is missing: run as a single-process group
    # that rendezvous over a fixed local TCP address.
    rank, local_rank, world_size = 0, 0, 1
    distributed.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:12584",
        rank=rank,
        world_size=world_size,
    )
def main(args):
    """Train the backbone together with a PartialFC head under DDP.

    Uses the module-level ``rank``, ``local_rank`` and ``world_size`` that
    were set when the process group was initialized. Checkpoints, the final
    weights and the ONNX export are written below ``cfg.output``.

    Args:
        args: Parsed command-line arguments; only ``args.config`` is read.

    Raises:
        ValueError: If ``cfg.optimizer`` is neither ``"sgd"`` nor ``"adamw"``.
    """
    # get config
    cfg = get_config(args.config)
    # global control random seed
    setup_seed(seed=cfg.seed, cuda_deterministic=False)

    torch.cuda.set_device(local_rank)

    os.makedirs(cfg.output, exist_ok=True)
    init_logging(rank, cfg.output)

    # Only rank 0 writes TensorBoard summaries.
    summary_writer = (
        SummaryWriter(log_dir=os.path.join(cfg.output, "tensorboard"))
        if rank == 0
        else None
    )

    wandb_logger = None
    if cfg.using_wandb:
        import wandb

        # Sign in to wandb. Failures are reported but do not abort training.
        try:
            wandb.login(key=cfg.wandb_key)
        except Exception as e:
            print("WandB Key must be provided in config file (base.py).")
            print(f"Config Error: {e}")
        # Initialize wandb
        run_name = datetime.now().strftime("%y%m%d_%H%M") + f"_GPU{rank}"
        run_name = run_name if cfg.suffix_run_name is None else run_name + f"_{cfg.suffix_run_name}"
        try:
            # Rank 0 always gets a run; other ranks only if wandb_log_all is set.
            wandb_logger = wandb.init(
                entity=cfg.wandb_entity,
                project=cfg.wandb_project,
                sync_tensorboard=True,
                resume=cfg.wandb_resume,
                name=run_name,
                notes=cfg.notes) if rank == 0 or cfg.wandb_log_all else None
            if wandb_logger:
                wandb_logger.config.update(cfg)
        except Exception as e:
            print("WandB Data (Entity and Project name) must be provided in config file (base.py).")
            print(f"Config Error: {e}")

    train_loader = get_dataloader(
        cfg.rec,
        local_rank,
        cfg.batch_size,
        cfg.dali,
        cfg.seed,
        cfg.num_workers
    )

    backbone = get_model(
        cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).cuda()

    backbone = torch.nn.parallel.DistributedDataParallel(
        module=backbone, broadcast_buffers=False, device_ids=[local_rank], bucket_cap_mb=16,
        find_unused_parameters=True)

    backbone.train()
    # FIXME using gradient checkpoint if there are some unused parameters will cause error
    backbone._set_static_graph()

    margin_loss = CombinedMarginLoss(
        64,
        cfg.margin_list[0],
        cfg.margin_list[1],
        cfg.margin_list[2],
        cfg.interclass_filtering_threshold
    )

    if cfg.optimizer == "sgd":
        module_partial_fc = PartialFC(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
        module_partial_fc.train().cuda()
        # TODO the params of partial fc must be last in the params list
        opt = torch.optim.SGD(
            params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
            lr=cfg.lr, momentum=0.9, weight_decay=cfg.weight_decay)

    elif cfg.optimizer == "adamw":
        module_partial_fc = PartialFCAdamW(
            margin_loss, cfg.embedding_size, cfg.num_classes,
            cfg.sample_rate, cfg.fp16)
        module_partial_fc.train().cuda()
        opt = torch.optim.AdamW(
            params=[{"params": backbone.parameters()}, {"params": module_partial_fc.parameters()}],
            lr=cfg.lr, weight_decay=cfg.weight_decay)
    else:
        # A bare ``raise`` here has no active exception to re-raise and only
        # produces a confusing RuntimeError; name the offending value instead.
        raise ValueError(
            f"Unsupported optimizer {cfg.optimizer!r}; expected 'sgd' or 'adamw'."
        )

    # Derive the step counts from the global batch size (per-process batch
    # size times the number of processes).
    cfg.total_batch_size = cfg.batch_size * world_size
    cfg.warmup_step = cfg.num_image // cfg.total_batch_size * cfg.warmup_epoch
    cfg.total_step = cfg.num_image // cfg.total_batch_size * cfg.num_epoch

    lr_scheduler = PolyScheduler(
        optimizer=opt,
        base_lr=cfg.lr,
        max_steps=cfg.total_step,
        warmup_steps=cfg.warmup_step,
        last_epoch=-1
    )

    start_epoch = 0
    global_step = 0
    if cfg.resume:
        # Each rank restores its own checkpoint file.
        # NOTE(review): torch.load unpickles the file, so only resume from
        # checkpoints written by a trusted run.
        dict_checkpoint = torch.load(os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))
        start_epoch = dict_checkpoint["epoch"]
        global_step = dict_checkpoint["global_step"]
        backbone.module.load_state_dict(dict_checkpoint["state_dict_backbone"])
        module_partial_fc.load_state_dict(dict_checkpoint["state_dict_softmax_fc"])
        opt.load_state_dict(dict_checkpoint["state_optimizer"])
        lr_scheduler.load_state_dict(dict_checkpoint["state_lr_scheduler"])
        del dict_checkpoint

    # Log the effective configuration, one aligned "key value" line per entry.
    for key, value in cfg.items():
        num_space = 25 - len(key)
        logging.info(": " + key + " " * num_space + str(value))

    callback_verification = CallBackVerification(
        val_targets=cfg.val_targets, rec_prefix=cfg.rec,
        summary_writer=summary_writer, wandb_logger=wandb_logger
    )
    callback_logging = CallBackLogging(
        frequent=cfg.frequent,
        total_step=cfg.total_step,
        batch_size=cfg.batch_size,
        start_step=global_step,
        writer=summary_writer
    )

    loss_am = AverageMeter()
    amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)

    for epoch in range(start_epoch, cfg.num_epoch):

        if isinstance(train_loader, DataLoader):
            train_loader.sampler.set_epoch(epoch)
        for img, local_labels in train_loader:
            global_step += 1
            local_embeddings = backbone(img)
            loss: torch.Tensor = module_partial_fc(local_embeddings, local_labels, opt)

            if cfg.fp16:
                # Scaled backward pass; unscale before clipping so the norm
                # threshold applies to the unscaled gradients.
                amp.scale(loss).backward()
                amp.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                amp.step(opt)
                amp.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(backbone.parameters(), 5)
                opt.step()

            opt.zero_grad()
            lr_scheduler.step()

            with torch.no_grad():
                # Read the scalar loss once and reuse it for every consumer.
                loss_value = loss.item()
                if wandb_logger:
                    wandb_logger.log({
                        'Loss/Step Loss': loss_value,
                        'Loss/Train Loss': loss_am.avg,
                        'Process/Step': global_step,
                        'Process/Epoch': epoch
                    })

                loss_am.update(loss_value, 1)
                callback_logging(global_step, loss_am, epoch, cfg.fp16, lr_scheduler.get_last_lr()[0], amp)

                if global_step % cfg.verbose == 0 and global_step > 0:
                    callback_verification(global_step, backbone)

        if cfg.save_all_states:
            # Full training state, one file per rank; "epoch" stores the next
            # epoch to run so a resumed job continues after this one.
            checkpoint = {
                "epoch": epoch + 1,
                "global_step": global_step,
                "state_dict_backbone": backbone.module.state_dict(),
                "state_dict_softmax_fc": module_partial_fc.state_dict(),
                "state_optimizer": opt.state_dict(),
                "state_lr_scheduler": lr_scheduler.state_dict()
            }
            torch.save(checkpoint, os.path.join(cfg.output, f"checkpoint_gpu_{rank}.pt"))

        if rank == 0:
            # Backbone weights are overwritten after every epoch.
            path_module = os.path.join(cfg.output, "model.pt")
            torch.save(backbone.module.state_dict(), path_module)

            if wandb_logger and cfg.save_artifacts:
                artifact_name = f"{run_name}_E{epoch}"
                model = wandb.Artifact(artifact_name, type='model')
                model.add_file(path_module)
                wandb_logger.log_artifact(model)

        if cfg.dali:
            train_loader.reset()

    if rank == 0:
        # Final weights, plus an ONNX export of the unwrapped backbone.
        path_module = os.path.join(cfg.output, "model.pt")
        torch.save(backbone.module.state_dict(), path_module)

        from torch2onnx import convert_onnx
        convert_onnx(backbone.module.cpu().eval(), path_module, os.path.join(cfg.output, "model.onnx"))

        if wandb_logger and cfg.save_artifacts:
            artifact_name = f"{run_name}_Final"
            model = wandb.Artifact(artifact_name, type='model')
            model.add_file(path_module)
            wandb_logger.log_artifact(model)

    distributed.destroy_process_group()
if __name__ == "__main__":
    # Enable cuDNN benchmark mode before any model is built.
    torch.backends.cudnn.benchmark = True

    # The only command-line argument is the config file.
    arg_parser = argparse.ArgumentParser(description="Distributed Arcface Training in Pytorch")
    arg_parser.add_argument("config", type=str, help="py config file")
    cli_args = arg_parser.parse_args()

    main(cli_args)
Reference documents: