PyTorch DDP distributed data gathering communication: torch.distributed.all_gather()

1. Overview from the official documentation

Official documentation link: torch.distributed.all_gather()

all_gather(tensor_list, tensor, group=None, async_op=False)

Each element of tensor_list receives the data gathered from one rank, while tensor is the tensor contributed by the current process. Every element of tensor_list must have the same shape as the tensor passed in by the corresponding rank.


Source code from the official documentation:

def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.
    Complex tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4])] # Rank 0
        [tensor([1, 2]), tensor([3, 4])] # Rank 1

        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
        >>> tensor_list
        [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j]) # Rank 0
        tensor([3.+3.j, 4.+4.j]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1

    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()

Example from the official documentation:

# All tensors below are of torch.int64 dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
tensor_list
tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
# All tensors below are of torch.cfloat dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
tensor_list
tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
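
The snippet above assumes that dist refers to torch.distributed, that the process group has already been initialized, and that rank holds the current process's rank. As a rough illustration (my own sketch, not from the official documentation), the int64 example can be run end to end by spawning two CPU processes with the gloo backend:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank, world_size):
    # Initialize a 2-process group on CPU; any free local port works here.
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29500",
        rank=rank,
        world_size=world_size,
    )
    tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(world_size)]
    tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
    dist.all_gather(tensor_list, tensor)
    print(f"rank {rank}: {tensor_list}")  # both ranks print [tensor([1, 2]), tensor([3, 4])]
    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)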

2. all_gather() does not propagate gradients: use it for model test/eval

torch.distributed.all_gather itself does not perform gradient backpropagation, as the following code shows:

import os

import torch
from torch import nn

# Rank information is read from the MPI environment variables set by mpirun.
batch_size = 16
rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
bs_each = batch_size // world_size
device_id = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
torch.cuda.set_device(device_id)
torch.distributed.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:12345',
    rank=rank,
    world_size=world_size,
)

model = nn.Linear(1, 1, bias=False)
model.weight.data[:] = 1.
model = model.cuda()
x = torch.ones((bs_each, 1), requires_grad=True).cuda()
y = model(x)
ys = [torch.zeros_like(y) for _ in range(world_size)]

torch.distributed.all_gather(ys, y)
print(y.grad_fn)
# <MmBackward object at 0x7f2073fc3ba8>
for sub_y in ys:
    print(sub_y.grad_fn)
    # None

Running this code first prints the real grad_fn of y, which was produced by the forward pass without all_gather. After calling all_gather, none of the tensors in ys has a grad_fn, i.e. no gradient is backpropagated through the gathered results.

In practice, it is recommended to wrap the all_gather call in torch.no_grad() to state explicitly that no gradients are propagated.
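
For example, a minimal sketch of this pattern (the tensor below is a placeholder, and an already-initialized process group is assumed):

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called;
# with the NCCL backend the tensors must live on the current GPU.
logits = torch.randn(8, 10)  # hypothetical per-rank evaluation outputs

with torch.no_grad():
    gathered = [torch.zeros_like(logits) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, logits)
all_logits = torch.cat(gathered, dim=0)  # has no grad_fn; nothing backpropagates through it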

Template code:

logits = torch.cat(logits_list, dim=0)
targets = torch.cat(targets_list, dim=0)

# For distributed parallel, collect all data and then run metrics.
if torch.distributed.is_initialized():
    # The gather list needs one entry per rank, i.e. torch.distributed.get_world_size() entries.
    logits_gather_list = [torch.zeros_like(logits) for _ in range(ngpus_per_node)]
    torch.distributed.all_gather(logits_gather_list, logits)
    logits = torch.cat(logits_gather_list, dim=0)

    targets_gather_list = [torch.zeros_like(targets) for _ in range(ngpus_per_node)]
    torch.distributed.all_gather(targets_gather_list, targets)
    targets = torch.cat(targets_gather_list, dim=0)

accuracy, recall, precision, auc = classification_metrics(logits, targets)

3. all_gather() with gradient propagation for model training

# Gather without autograd, then put the local grad-carrying tensor back into its slot.
with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(all_x, x)
all_x[rank] = x

all_x now contains the x produced by every GPU. None of the gathered tensors carries a grad_fn except the one from the current GPU, because of the assignment all_x[rank] = x. The loss can then be computed from all_x.

In other words, writing the current GPU's original tensor back into its rank index of all_x keeps the computation graph attached to all_x[rank], so every GPU can still backpropagate through its own contribution.

Template code:

import torch
import torch.distributed as dist

# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)

# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)

# Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica
# with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings

# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)

# Generating the labels is not shown here; it is task-dependent.
loss = some_contrastive_loss(embeddings, labels)

Reference link:
https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10

A gradient-carrying all_gather can also be implemented as a custom autograd.Function. The following three implementations (taken from SimCLR-style codebases) all work:
1.

import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all processes, supporting backward propagation.'''

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input)
                  for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        # Each rank keeps only the gradient corresponding to its own input slice.
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out

How to use:

allgather = GatherLayer.apply
features_gather = allgather(features)  # gather the features from all GPUs (returns a tuple of tensors)

Reference link:
https://i.steer.space/blog/2021/01/pytorch-dist-nccl-backend-allgather-stuck

2.

import torch
import torch.distributed


class SyncFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)

        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        # Sum the incoming gradients across all ranks, then return only the slice
        # that corresponds to this rank's original input chunk.
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)

        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]

How to use:

allgather = SyncFunction.apply
features_gather = allgather(features)  # gather the features from all GPUs into one tensor

Reference link: https://github.com/Lightning-AI/lightning-bolts/blob/5577453a6d7072724d9ae24184daf8f45d4baff7/pl_bolts/models/self_supervised/simclr/simclr_module.py

3.

import torch
import torch.distributed as dist


class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""

    @staticmethod
    def forward(ctx, tensor):
        output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = dist.get_rank()
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # forward() takes a single tensor, so backward() must return a single gradient:
        # the slice of grad_output that belongs to this rank's input.
        return grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)]

How to use:

allgather = AllGather.apply
features_gather = allgather(features)  # gather the features from all GPUs into one tensor

Reference link: https://github.com/ArrowLuo/CLIP4Clip
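
As a quick sanity check (a sketch of my own, not taken from the referenced repositories), the difference between the plain collective and a gradient-aware wrapper such as GatherLayer from implementation 1 above can be observed even with a single process on CPU using the gloo backend:

import torch
import torch.distributed as dist

# Single-process group, just to exercise the collectives on CPU.
dist.init_process_group(backend="gloo", init_method="tcp://127.0.0.1:29501",
                        rank=0, world_size=1)

x = torch.ones(4, 2, requires_grad=True)

plain = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(plain, x)
print(plain[0].grad_fn)  # None: the plain collective breaks the autograd graph

gathered = torch.cat(GatherLayer.apply(x), dim=0)  # GatherLayer defined in implementation 1
print(gathered.grad_fn)  # not None: the custom autograd.Function keeps the graph
gathered.sum().backward()
print(x.grad)            # tensor of ones: gradients flow back to x

dist.destroy_process_group()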

4. Related Links

  1. Official link: torch.distributed.all_gather()
  2. PyTorch: gradient backpropagation based on torch.distributed.all_gather
  3. PyTorch multi-process distributed training practice
  4. Troubleshooting PyTorch's distributed.all_gather getting stuck under the NCCL backend
  5. Basic concepts and issues involved in PyTorch distributed DDP
  6. PyTorch distributed training detailed tutorial: scatter, gather & isend, irecv & all_reduce & DDP
  7. Illustrated DistributedDataParallel (DDP) communication methods: gather, all_gather, all_reduce, reduce, scatter
