PyTorch DDP: gathering data across processes with torch.distributed.all_gather()

1. Official documentation

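torch.distributed.all_gather() (official documentation link)

all_gather(tensor_list, tensor, group=None, async_op=False)

Each element of tensor_list receives the data from one rank, and tensor is the tensor that the current process contributes. Every element of tensor_list must have the same shape as the tensor argument passed by each rank.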
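The shape rule above can be checked with a small sketch. The helper names `init_single_process_group` and `gather_same_shape`, and the one-process `gloo` group, are assumptions made only so the snippet runs on its own. In a real job the launcher would normally have initialized the process group already.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def init_single_process_group():
    """Create a throwaway 1-process gloo group (illustration only)."""
    if not dist.is_initialized():
        store_file = os.path.join(tempfile.mkdtemp(), "store")
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{store_file}",
            rank=0,
            world_size=1,
        )


def gather_same_shape(tensor):
    """Allocate one output buffer per rank, shaped like `tensor`, then gather."""
    world_size = dist.get_world_size()
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor)
    return tensor_list


init_single_process_group()
local = torch.arange(2, dtype=torch.int64) + 1 + 2 * dist.get_rank()
gathered = gather_same_shape(local)
print(len(gathered), gathered[0].shape)
```

Building the buffers with `torch.empty_like(tensor)` keeps shape, dtype and device aligned with the input, which is what the collective expects.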

Official source code:

def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.
    Complex tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4])] # Rank 0
        [tensor([1, 2]), tensor([3, 4])] # Rank 1

        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
        >>> tensor_list
        [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j]) # Rank 0
        tensor([3.+3.j, 4.+4.j]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()

Official example:

# All tensors below are of torch.int64 dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
tensor_list
tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
# All tensors below are of torch.cfloat dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
tensor_list
tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
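The async_op flag from the signature can be sketched as follows. As the source code shows, all_gather returns a work handle when async_op=True, and the caller must call wait() on it. The one-process `gloo` group here is an assumption that only serves to make the snippet runnable standalone.

```python
import os
import tempfile

import torch
import torch.distributed as dist

# Assumption: a throwaway 1-process gloo group so the sketch runs on its own.
if not dist.is_initialized():
    store_file = os.path.join(tempfile.mkdtemp(), "store")
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{store_file}",
        rank=0,
        world_size=1,
    )

rank, world_size = dist.get_rank(), dist.get_world_size()
tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(world_size)]

# async_op=True returns a work handle instead of blocking until completion.
work = dist.all_gather(tensor_list, tensor, async_op=True)
# ... unrelated computation could overlap with the communication here ...
work.wait()  # only read tensor_list after wait() has returned

result = torch.stack(tensor_list)
print(result)
```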

2. all_gather() without gradient propagation, for model test or eval

torch.distributed.all_gather does not backpropagate gradients by itself, as the following code shows.

import os

import torch
from torch import nn

batch_size = 16
# Rank and world size are read from the Open MPI environment variables.
rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
bs_each = batch_size // world_size
device_id = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
torch.cuda.set_device(device_id)
torch.distributed.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:12345',
    rank=rank,
    world_size=world_size,
)

model = nn.Linear(1, 1, bias=False)
model.weight.data[:] = 1.
model = model.cuda()
x = torch.ones((bs_each, 1), requires_grad=True).cuda()
y = model(x)
# One output buffer per process in the group.
ys = [torch.zeros_like(y) for _ in range(world_size)]

torch.distributed.all_gather(ys, y)
print(y.grad_fn)
# <MmBackward object at 0x7f2073fc3ba8>
for sub_y in ys:
    print(sub_y.grad_fn)
    # None

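Running this code first prints y.grad_fn, the real gradient function of y before all_gather is involved. After all_gather is called, the tensors in ys have no grad_fn, which means no gradient is backpropagated through them.

In practice, it is recommended to wrap the all_gather call in torch.no_grad() to state explicitly that no gradient is backpropagated.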
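A minimal sketch of that recommendation is shown here. The helper name `gather_for_eval` is hypothetical, the random tensor merely stands in for a model output, and the one-process `gloo` group is an assumption so the snippet runs without a launcher.

```python
import os
import tempfile

import torch
import torch.distributed as dist

# Assumption: a throwaway 1-process gloo group so the sketch runs on its own.
if not dist.is_initialized():
    store_file = os.path.join(tempfile.mkdtemp(), "store")
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{store_file}",
        rank=0,
        world_size=1,
    )


@torch.no_grad()
def gather_for_eval(tensor):
    """Gather same-shaped tensors from every rank and concatenate on dim 0."""
    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    return torch.cat(gathered, dim=0)


# Stand-in for a model output that is still attached to the autograd graph.
logits = torch.randn(4, 3, requires_grad=True) * 2.0
all_logits = gather_for_eval(logits)
print(all_logits.requires_grad, all_logits.grad_fn)
```

Using the decorator form of torch.no_grad() keeps the intent visible at the function boundary: everything returned by the helper is detached from the graph and is meant for metrics only.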

Template code:

logits = torch.cat(logits_list, dim=0)
targets = torch.cat(targets_list, dim=0)

# For distributed parallel, collect all data and then run metrics.
if torch.distributed.is_initialized():
    # One buffer per process in the group (the world size), not per node.
    world_size = torch.distributed.get_world_size()

    logits_gather_list = [torch.zeros_like(logits) for _ in range(world_size)]
    torch.distributed.all_gather(logits_gather_list, logits)
    logits = torch.cat(logits_gather_list, dim=0)

    targets_gather_list = [torch.zeros_like(targets) for _ in range(world_size)]
    torch.distributed.all_gather(targets_gather_list, targets)
    targets = torch.cat(targets_gather_list, dim=0)

accuracy, recall, precision, auc = classification_metrics(logits, targets)

3. all_gather() with gradient propagation, for model training

with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(all_x, x)
all_x[rank] = x

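all_x holds the x produced on every GPU. None of these tensors has a grad_fn except the x produced on the current GPU, because of the assignment all_x[rank] = x. The loss can then be computed from all_x and f.

In other words, the original tensor on the current GPU is assigned to its own rank index in all_x, so that all_x[rank] can carry gradients and backpropagation can proceed (on every GPU).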
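The effect of the assignment can be verified with a small sketch. The fixed weight `w`, the sum used as a stand-in loss, and the one-process `gloo` group are all assumptions for illustration and are not part of the original code.

```python
import os
import tempfile

import torch
import torch.distributed as dist

# Assumption: a throwaway 1-process gloo group so the sketch runs on its own.
if not dist.is_initialized():
    store_file = os.path.join(tempfile.mkdtemp(), "store")
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{store_file}",
        rank=0,
        world_size=1,
    )

rank, world_size = dist.get_rank(), dist.get_world_size()

x = torch.ones(4, 1, requires_grad=True)
w = torch.full((1, 1), 3.0)  # hypothetical fixed weight standing in for a model
y = x @ w                    # y is attached to the autograd graph

with torch.no_grad():
    all_y = [torch.zeros_like(y) for _ in range(world_size)]
    dist.all_gather(all_y, y)
all_y[rank] = y  # put the graph-attached tensor back into the local slot

# Only the local slot carries a grad_fn; the gathered copies do not.
has_grad_fn = [t.grad_fn is not None for t in all_y]

loss = torch.cat(all_y, dim=0).sum()  # stand-in for a real loss
loss.backward()
print(has_grad_fn, x.grad.flatten().tolist())
```

Each process therefore only receives the gradient that flows through its own slot; contributions that other processes' losses make to this slot are not included in this scheme.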

Template code:

import torch
import torch.distributed as dist

# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)

# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)

# Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica
# with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings

# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)

# I didn't demonstrate how to generate the labels, this will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)

Reference:
https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10

The following three gradient-aware all_gather implementations also work (code from SimCLR models):
1.

class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all processes, supporting backward propagation.
    '''

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) \
            for _ in range(dist.get_world_size())]

        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out

Usage:

allgather = GatherLayer.apply
features_gather = allgather(features)  # gather the data from all GPUs

Reference:
https://i.steer.space/blog/2021/01/pytorch-dist-nccl-backend-allgather-stuck

2.

class SyncFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]

Usage:

allgather = SyncFunction.apply
features_gather = allgather(features)  # gather the data from all GPUs

Reference: https://github.com/Lightning-AI/lightning-bolts/blob/5577453a6d7072724d9ae24184daf8f45d4baff7/pl_bolts/models/self_supervised/simclr/simclr_module.py

3.

import torch
import torch.distributed as dist
class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor):
        output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = dist.get_rank()
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
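    def backward(ctx, grad_output):
        # forward() takes a single tensor input, so return exactly one gradient:
        # the slice of grad_output that corresponds to this rank's batch.
        start = ctx.batch_size * ctx.rank
        end = ctx.batch_size * (ctx.rank + 1)
        return grad_output[start:end]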

Usage:

allgather = AllGather.apply
features_gather = allgather(features)  # gather the data from all GPUs

Reference: https://github.com/ArrowLuo/CLIP4Clip

Reposted from blog.csdn.net/flyingluohaipeng/article/details/129134552