Single-machine multi-GPU distributed training in PyTorch with DistributedDataParallel

The goal of this article is to get single-machine, multi-GPU distributed training working in PyTorch in one pass; multi-machine multi-GPU training is not covered for now. There is no discussion of the principles behind PyTorch distributed training: the point is to use multiple GPUs directly and quickly in a few steps, including saving and loading the distributed model. A previous article gave a brief record of this, but it had some problems and was not detailed enough.

PyTorch implements single-machine multi-GPU training with DataParallel and DistributedDataParallel, i.e. DP and DDP.

DP: torch.nn.DataParallel

DDP: torch.nn.parallel.DistributedDataParallel

The former, DP, is relatively simple, just two lines of code, but it is not truly distributed. The latter runs one process per GPU, so the different GPUs end up using roughly the same amount of video memory. Only the latter is covered here.
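For reference, the DP version is a minimal sketch like the following (MyModel stands in for your own network, as in the DDP code below):

import torch.nn as nn

model = MyModel()                      # your own model
model = nn.DataParallel(model).cuda()  # DP: a single process drives all visible GPUs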

1. Training code and startup

import argparse

import torch
from torch.utils.data import Dataset, DataLoader
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# step1: set up the communication backend and device; local_rank normally comes from the command line.
# When launching with torch.distributed.launch, the --local_rank argument is passed in automatically.
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=-1)
FLAGS = parser.parse_args()
local_rank = FLAGS.local_rank
torch.cuda.set_device(local_rank)
dist.init_process_group(backend='nccl')  # nccl communication backend
device = torch.device("cuda", local_rank)


# step2: shard the data across processes -- an important step
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, sampler=train_sampler, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, sampler=val_sampler, num_workers=2)  # shuffle must be False when a sampler is given; shuffle yourself beforehand if needed.


# step3: build the model and wrap it with DDP
model = MyModel().to(device)  # your own model
# model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)  # enable synchronized BN if your model needs it
model = DDP(model, find_unused_parameters=True, device_ids=[local_rank], output_device=local_rank)  # wrapping with DDP prefixes the model's state_dict keys with "module."

# step4: at the start of each training epoch
for epoch in range(1, CFG.epochs + 1):
    train_loader.sampler.set_epoch(epoch)  # gives every process the same shuffling seed for this epoch
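    # The per-batch code below is a minimal sketch of the rest of the epoch, not part of the
    # original snippet; criterion and optimizer are hypothetical names you define yourself.
    # The only DDP-specific detail is moving each batch to this process's device.
    model.train()
    for images, labels in train_loader:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(images)            # DDP synchronizes gradients during backward()
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()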


Start training:

CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
# --nproc_per_node=2: set this to the number of GPUs you are using
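As a side note, newer PyTorch releases deprecate torch.distributed.launch in favor of torchrun; an equivalent launch would look roughly like the line below. Be aware that torchrun passes the local rank through the LOCAL_RANK environment variable rather than the --local_rank argument, so the argparse code in step 1 would need a small adjustment.

CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 train.py
# in train.py: local_rank = int(os.environ["LOCAL_RANK"]) instead of reading --local_rank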

2. Save the model

if dist.get_rank() == 0:  # depending on your needs, save only on rank 0 or on every rank; if every rank saves, watch out for file-name clashes
    temp_model_path = CFG.model_save_dir + "/" + "temp_{}".format(epoch) + ".pth"
    torch.save(model.state_dict(), temp_model_path)
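If you would rather not deal with the "module." prefix at load time at all, a common alternative (a sketch, using the same temp_model_path as above) is to save the unwrapped model's weights instead:

if dist.get_rank() == 0:
    # model.module is the original, unwrapped model, so its state_dict has no "module." prefix
    torch.save(model.module.state_dict(), temp_model_path)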

3. Model loading

With the saving code above, DDP prefixes every state_dict key with "module." when saving. Depending on how you saved, if the keys carry "module." you can strip it with the method below, or you can avoid it when saving the model in the first place (see the sketch in the previous section).

import torch
from collections import OrderedDict

checkpoint = torch.load(pathmodel, map_location=torch.device('cpu'))
new_state_dict = OrderedDict()
for k, v in checkpoint.items():
    name = k.replace("module.", "")  # remove the `module.` prefix
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)  # load once, after the loop, into the plain (unwrapped) model
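Recent PyTorch versions also ship a small helper that strips the prefix in place; if your version includes it, the loop above can be replaced with the following sketch:

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

checkpoint = torch.load(pathmodel, map_location=torch.device('cpu'))
consume_prefix_in_state_dict_if_present(checkpoint, "module.")  # removes "module." from keys in place
model.load_state_dict(checkpoint)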

That is all: by following these steps, you can turn single-machine single-GPU training into single-machine multi-GPU distributed training in about 3 minutes.

If you find it useful, please give the blogger a thumbs up.

Follow up:

CSDN's posting assistant flagged this article as low quality, so I am testing how many extra lines need to be written to stop the low-quality detection.

They should really improve the feature rather than nag about it; fine, maybe the article quality is low, so I typed a few more lines. I do wonder whether this posting assistant is just built out of hard-coded if logic. Utterly absurd.

Added four lines: the low-quality warning still appears.

Added five lines: the low-quality warning still appears.

Added six lines: the low-quality warning is gone.
