一起养成写作习惯!这是我参与「掘金日新计划 · 4 月更文挑战」的第7天,点击查看活动详情。
在训练神经网络模型时,我们更偏爱大规模的数据集和复杂的网络结构。虽然其可以让我们的模型表征能力更强,但同时也对计算的时间和空间提出了挑战。
为什么要使用分布式训练
随着数据集体积的不断扩大,动辄TB甚至是PB的数据量使得单卡训练成为幻影,即使对于普通的数据集和模型,更快的速度,更大的显存,更高的显卡利用率,以及更大的batchsize带来更好的性能也都是我们所追求的。
一些分布式训练算法
所谓分布式,指的是计算节点之间不共享内存,需要通过网络通信的方式交换数据,以下介绍几种不同的分布式训练算法。
Spark MLlib
Spark的分布式计算原理:
首先需要了解一下Spark分布式计算原理
额,没太看懂,大概就是
-
Driver: 创建SparkContext,来提供程序运行所需要的环境,并且负责与Cluster Manager进行通信来实现资源申请、任务分配和监控等功能,当Executor关闭时,Driver同时将SparkContext关闭;
-
Executor: 工作节点(Worker Node)中的一个进程,负责运行Task;
-
Task: 运行在Executor上的工作单元;
-
RDD: 弹性分布式数据集,是分布式内存的一个抽象概念,提供了一种高度受限的共享内存模型,其具有以下性质:
- 只读不能修改,只能通过转换生成新的RDD;
- 可以分布在多台机器上并行处理;
- 弹性:计算中内存不够会和磁盘进行数据交换;
- 基于内存:可以全部或部分缓存在内存中,在多次计算间重用。
-
partition: RDD基础处理单元是partition(分区),一个Work Node中有多个partition,每个partition的计算都在一个不同的task中进行
-
DAG: 有向无环图,反映RDD之间的依赖关系,在执行具体任务之前,Spark会将程序拆解成一个任务DAG;处理DAG最关键的过程是找到哪些是可以并行处理的部分,哪些是必须shuffle和reduce的部分。
-
shuffle: 指的是所有partition的数据必须进行洗牌后才能得到下一步的数据
-
Job: 一个Job包含多个RDD及作用于相应RDD上的各种操作;
-
Stage: 是Job的基本调度单位,一个Job会分为多组Task,每组Task被称为Stage,或者也被称为TaskSet,代表一组关联的,相互之间没有Shuffle依赖关系的任务组成的任务集;
-
Cluter Manager: 指的是在集群上获取资源的外部服务。目前有三种类型
- Standalon : spark原生的资源管理,由Master负责资源的分配;
- Apache Mesos:与hadoop MR兼容性良好的一种资源调度框架;
- Hadoop Yarn: 主要是指Yarn中的ResourceManager。
Spark MLlib并行训练原理:
简单的数据并行,缺点如下:
- 采用全局广播的方式,在每轮迭代前广播全部模型参数。
- 采用阻断式的梯度下降方式,每轮梯度下降由最慢的节点决定。 Spark等待所有节点计算完梯度之后再进行汇总。
- Spark MLlib并不支持复杂网络结构和大量可调超参,对深度学习支持较弱。
Parameter Server
Parameter Server由李沐大佬提出,如今已被各个框架应用于分布式训练当中。论文地址
在分布式训练当中,输入数据、模型、反向传播计算梯度都是可以并行的,但是更新参数依赖于所有的训练样本的梯度,这是不能并行的,如下图所示:
Parameter Server包含一个参数(server)服务器(或者GPU等)来复制分发model到各个工作(worker)服务器上,计算示意图如下:
- 将数据和初始化参数加载到server当中,若无法一次加载进来也可分多次加载;
- server将输入数据进行切片,分发给各个的worker;
- server复制模型,传递给各个worker;
- 各个worker并行进行计算(forward和backward);
- 将各个worker求得的梯度求平均,返回server进行更新(push:worker将计算的梯度传送给server),同时回到第二步,重新分发更新过后的模型参数(pull:worker从server拉取参数)。
通过上图可以看到每个worker之间没有任何信息交换,它们都只与server通信。
上述过程貌似和Spark差不多,实际上PS中的server和worker中存在很多节点,它们分别组成server group和worker group,功能与上述基本一致,对于server group,存在server manager来维护和分配各个server node的数据,如下图所示:
缺点:
上文提到的Spark使用的是同步阻断的方式进行更新,只有等所有节点的梯度计算完成后才能进行参数更新,会浪费大量时间;
对此,PS使用了异步非阻断的方式进行更新,当某个worker节点完成了push之后,其他节点没有进行push,这表示该节点无法进行pull来拉取新模型,此时该节点会直接再次进行计算,并且将当次计算的梯度在下一次push提交给server。
这种取舍虽然使得训练速度大幅增加,但是却造成模型的一致性有所丧失,具体影响还是有待讨论。
该方法可以通过最大延迟来限制这种异步操作,即某个worker节点最多只能超过server几轮迭代,超过之后必须停下来等待pull。
多server节点的协同和效率问题:
Spark效率低下的另一个原因是每次更新参数之后都需要使用单个master节点将模型参数广播至各个worker节点;
由于Parameter Server使用了多server node结构的server group,每个server node负责模型参数中的K-V对,并行地进行数据传播,大大加快了速度;
使用哈希一致性来保证每个server node负责对应的key range,并且能保证已有key range变化不大的情况下添加新的节点(看不懂)
Ring AllReduce
摒弃了使用server来进行输入传输,而是将各个worker连成环,进行循环传递来达到“混合”的效果,主要流程:
- 对于N个worker,将其连成一个环,并且将每个worker中的梯度分成N份;
- 对于第k个worker,其会将第k份梯度发送给下一个节点,同时从前一个节点收到第k-1份梯度;
- 将收到的k-1份梯度和原有的梯度整合,循环N次,这样每个节点都会包含所有节点梯度的对应的一份;
- 每个worker再将整合好的梯度发给下一个worker即可,需要注意的是,这里直接使用前一个worker的梯度覆盖当前的梯度,依然循环N次。
- 最后每个worker都会得到所有梯度,除以N即可进行参数更新。
更多算法请看此
NCCL
NCCL是Nvidia Collective multi-GPU Communication Library的简称,它是一个实现多GPU的collective。
pytorch code
pytorch的分布式训练目前仅支持linux系统。
pytorch数据分布式训练类型主要有:
- 单机单卡: 最简单的训练类型;
- 单机多卡: 代码改动较少;
- 多机多卡: 多台机器上的多张显卡,机制较为复杂;
- 其他: 其他的一些情况,不介绍。
DataParallel
单机多卡的情况,代码改动较少,主要基于nn.DataParallel
,是一种单进程多线程
的训练方式,存在GIL冲突。
torch.nn.DataParallel
定义如下:
class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
复制代码
包含三个参数:
- module:表示需要Parallel的模型,实际上最终得到的是一个
nn.Module
- device_ids:表示训练用GPU的编号
- output_device:表示输出的device,用于汇总梯度和参数更新,默认选择0卡
- dim:表示处理loss的维度,默认0表示在batch上处理,使用了多少GPU就会返回多少个loss
不想使用0卡可以用如下方式指定训练设备
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # 表示按照PCI_BUS_ID顺序从0开始排列GPU设备,不使用指定多卡时会报错
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3" # -1表示禁用GPU
复制代码
训练代码为:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.device_count()>1:
model = nn.DataParallel(model)
model.to(device)
复制代码
更优雅点可以这样:
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model,device_ids=range(torch.cuda.device_count())) # 直接全用
复制代码
也可以使用如下方式指定训练GPU:
device_ids = [2, 3]
net = torch.nn.DataParallel(net, device_ids=device_ids)
复制代码
机制很简单,即将模型复制到各个GPU上进行forward和backward,由server卡汇总计算平均梯度,再分发更新后的参数至各个worker。这就导致了一个问题——负载不均衡,server卡的占用会很高,其原因主要是在server上计算loss,可以在每个gpu上计算loss之后返回给server求平均。
在保存和加载模型时
# 保存模型需要使用.module
torch.save(net.module.state_dict(), path)
# 或者
torch.save(net.module, path)
# 正常加载模型
net.load_state_dict(torch.load(path))
# 或者
net = torch.load(path)
复制代码
另一个需要注意的点是,使用DataParallel
时的batch_size
表示总的batch_size,每个GPU上会分到n分之一。
DistributedDataParallel
多机多卡的训练方式,多进程
不存在GIL冲突,也适用于单机单卡,并且速度较快。
DistributedDataParallel
需要一个init_process_group
的步骤来启动
class torch.distributed.init_process_group(backend,
init_method=None,
timeout=datetime.timedelta(0, 1800),
world_size=-1,
rank=-1,
store=None)
复制代码
- backend:str:gloo,mpi,nccl 。指定当前要使用的通信后端,通常使用 nccl;
- init_method:指定当前进程组的初始化方式;
- timeout:指定每个进程的超时时间,仅可用于"gloo"后端;
- world_size:总进程数;
- store:所有
worker
可访问的key
/value
,用于交换连接 / 地址信息。与init_method
互斥。
其他函数:
torch.distributed.get_backend(group=group) # group是可选参数,返回字符串表示的后端 group表示的是ProcessGroup类
torch.distributed.get_rank(group=group) # group是可选参数,返回int,执行该脚本的进程的rank
torch.distributed.get_world_size(group=group) # group是可选参数,返回全局的整个的进程数
torch.distributed.is_initialized() # 判断该进程是否已经初始化
torch.distributed.is_mpi_avaiable() # 判断MPI是否可用
torch.distributed.is_nccl_avaiable() # 判断nccl是否可用
复制代码
简单的使用:
初始化
torch.distributed.init_process_group(backend='nccl', init_method='env://')
复制代码
DistributedSampler
:
注意这里与DataParallel不同,batch_size的大小表示每个GPU上的大小,需要将dataloader切分。
train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False,
num_workers=2, pin_memory=True, sampler=train_sampler,)
复制代码
这里需要设置shuffle=False
,然后在每个epoch前,通过调用train_sampler.set_epoch(epoch)
来达到shuffle的效果.
模型的初始化:
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank])
复制代码
同步BN:
BN层对于较大的batch_size有更好的性能,所以对于BN需要使用所有卡上的数据来进行计算。
使用Apex:
from apex.parallel import convert_syncbn_model
from apex.parallel import DistributedDataParallel
# 注意顺序:三个顺序不能错
model = convert_syncbn_model(UNet3d(n_channels=1, n_classes=1)).to(device)
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
model = DistributedDataParallel(model, delay_allreduce=True)
复制代码
或者使用这里的代码来代替nn.BatchNorm
训练:
提供了torch.distributed.launch
用于启动训练
export CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 train.py
复制代码
nproc_per_node参数指定为当前主机创建的进程数。一般设定为当前主机的 GPU
数量
模板:
import torch.distributed as dist
import torch.utils.data.distributed
# ......
parser = argparse.ArgumentParser(description='PyTorch distributed training on cifar-10')
parser.add_argument('--rank', default=0,
help='rank of current process')
parser.add_argument('--word_size', default=2,
help="word size")
parser.add_argument('--init_method', default='tcp://127.0.0.1:23456',
help="init-method")
args = parser.parse_args()
# ......
dist.init_process_group(backend='nccl', init_method=args.init_method, rank=args.rank, world_size=args.word_size)
# ......
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=download, transform=transform)
train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler)
# ......
net = Net()
net = net.cuda()
net = torch.nn.parallel.DistributedDataParallel(net)
复制代码