Article Directory
1. The official documentation
torch.distributed.all_gather() official documentation link
all_gather(tensor_list, tensor, group=None, async_op=False)
Each element of tensor_list receives the data gathered from one rank, while tensor is the local tensor contributed by the current process. Every tensor in tensor_list must have the same shape as the tensor argument passed in on the corresponding rank.
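To illustrate these semantics without a multi-GPU setup, here is a minimal single-process sketch. The gloo backend and the file-based rendezvous are my choices for portability; with world_size=1 the gather trivially copies the local tensor into the single output slot:

```python
import os
import tempfile

import torch
import torch.distributed as dist

# Single-process demo: gloo backend + file rendezvous, so it runs on CPU
# without any launcher. With world_size=1 the "group" is just this process.
sync_file = os.path.join(tempfile.mkdtemp(), "rendezvous")
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{sync_file}",
    rank=0,
    world_size=1,
)

tensor = torch.tensor([1, 2], dtype=torch.int64)
# One correctly-sized output slot per rank.
tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(dist.get_world_size())]
dist.all_gather(tensor_list, tensor)
print(tensor_list)  # [tensor([1, 2])]

dist.destroy_process_group()
```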
Official website source code:
def all_gather(tensor_list,
               tensor,
               group=None,
               async_op=False):
    """
    Gathers tensors from the whole group in a list.

    Complex tensors are supported.

    Args:
        tensor_list (list[Tensor]): Output list. It should contain
            correctly-sized tensors to be used for output of the collective.
        tensor (Tensor): Tensor to be broadcast from current process.
        group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.
        async_op (bool, optional): Whether this op should be an async op

    Returns:
        Async work handle, if async_op is set to True.
        None, if not async_op or if not part of the group

    Examples:
        >>> # All tensors below are of torch.int64 dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
        >>> tensor_list
        [tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
        >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
        >>> tensor
        tensor([1, 2]) # Rank 0
        tensor([3, 4]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1, 2]), tensor([3, 4])] # Rank 0
        [tensor([1, 2]), tensor([3, 4])] # Rank 1

        >>> # All tensors below are of torch.cfloat dtype.
        >>> # We have 2 process groups, 2 ranks.
        >>> tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
        >>> tensor_list
        [tensor([0.+0.j, 0.+0.j]), tensor([0.+0.j, 0.+0.j])] # Rank 0 and 1
        >>> tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
        >>> tensor
        tensor([1.+1.j, 2.+2.j]) # Rank 0
        tensor([3.+3.j, 4.+4.j]) # Rank 1
        >>> dist.all_gather(tensor_list, tensor)
        >>> tensor_list
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 0
        [tensor([1.+1.j, 2.+2.j]), tensor([3.+3.j, 4.+4.j])] # Rank 1
    """
    _check_tensor_list(tensor_list, "tensor_list")
    _check_single_tensor(tensor, "tensor")
    if _rank_not_in_group(group):
        return

    tensor_list = [t if not t.is_complex() else torch.view_as_real(t) for t in tensor_list]
    tensor = tensor if not tensor.is_complex() else torch.view_as_real(tensor)

    if group is None:
        default_pg = _get_default_group()
        work = default_pg.allgather([tensor_list], [tensor])
    else:
        work = group.allgather([tensor_list], [tensor])

    if async_op:
        return work
    else:
        work.wait()
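The torch.view_as_real calls in the body above are what make complex tensors work: a complex tensor is reinterpreted as a real tensor with a trailing dimension of 2 before being handed to the backend, and because the real view shares memory with the complex tensor, gathering into the views fills the original complex outputs. A small standalone sketch of that memory sharing:

```python
import torch

# A complex tensor and its real view: shape (2,) cfloat -> shape (2, 2) float.
z = torch.tensor([1 + 1j, 2 + 2j], dtype=torch.cfloat)
r = torch.view_as_real(z)  # shares memory with z
print(r)
# tensor([[1., 1.],
#         [2., 2.]])

# Writing through the real view updates the complex tensor, which is why
# all_gather can operate on the views and still fill the complex outputs.
r[0, 0] = 5.0
print(z)  # tensor([5.+1.j, 2.+2.j])
```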
Official website example:
# All tensors below are of torch.int64 dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]
tensor_list
tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
# All tensors below are of torch.cfloat dtype.
# We have 2 process groups, 2 ranks.
tensor_list = [torch.zeros(2, dtype=torch.cfloat) for _ in range(2)]
tensor_list
tensor = torch.tensor([1+1j, 2+2j], dtype=torch.cfloat) + 2 * rank * (1+1j)
tensor
dist.all_gather(tensor_list, tensor)
tensor_list
2. all_gather() without gradient propagation, for model test/eval
torch.distributed.all_gather itself does not propagate gradients backward, as the following code demonstrates:
import os
import torch

batch_size = 16
rank = int(os.environ.get('OMPI_COMM_WORLD_RANK', '0'))
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE', '1'))
bs_each = batch_size // world_size
device_id = int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK', '0'))
torch.cuda.set_device(device_id)
torch.distributed.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:12345',
    rank=rank,
    world_size=world_size,
)

from torch import nn
model = nn.Linear(1, 1, bias=False)
model.weight.data[:] = 1.
model = model.cuda()

x = torch.ones((bs_each, 1), requires_grad=True).cuda()
y = model(x)
ys = [torch.zeros_like(y) for _ in range(world_size)]

torch.distributed.all_gather(ys, y)
print(y.grad_fn)
# <MmBackward object at 0x7f2073fc3ba8>
for sub_y in ys:
    print(sub_y.grad_fn)
# None
Running this code first prints y.grad_fn, the gradient function of the local output y, which has not passed through all_gather. After calling all_gather, every tensor in ys prints grad_fn as None: the gathered tensors are detached from the autograd graph, so no gradient can backpropagate through them.
In practice, it is recommended to wrap the all_gather call in torch.no_grad() to state explicitly that no gradient will be backpropagated.
Template code:
logits = torch.cat(logits_list, dim=0)
targets = torch.cat(targets_list, dim=0)
# For distributed parallel, collect all data and then run metrics.
if torch.distributed.is_initialized():
    logits_gather_list = [torch.zeros_like(logits) for _ in range(ngpus_per_node)]
    torch.distributed.all_gather(logits_gather_list, logits)
    logits = torch.cat(logits_gather_list, dim=0)
    targets_gather_list = [torch.zeros_like(targets) for _ in range(ngpus_per_node)]
    torch.distributed.all_gather(targets_gather_list, targets)
    targets = torch.cat(targets_gather_list, dim=0)
accuracy, recall, precision, auc = classification_metrics(logits, targets)
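The template above assumes an initialized process group and an `ngpus_per_node` variable. A self-contained variant wrapped in torch.no_grad() as recommended could look like the following sketch (the function name gather_for_metrics and the single-process fallback are my own additions):

```python
import torch
import torch.distributed as dist

@torch.no_grad()  # explicit: gathered tensors are for metrics only, no gradients
def gather_for_metrics(local: torch.Tensor) -> torch.Tensor:
    """Gather `local` from every rank and concatenate along dim 0.

    Falls back to returning `local` unchanged when not running distributed,
    so the same evaluation code works in single-GPU and multi-GPU runs.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return local
    buckets = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(buckets, local)
    return torch.cat(buckets, dim=0)
```

Metrics would then be computed on gather_for_metrics(logits) and gather_for_metrics(targets).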
3. all_gather() with gradient propagation, for model training
with torch.no_grad():
    all_x = [torch.zeros_like(x) for _ in range(world_size)]
    torch.distributed.all_gather(all_x, x)
all_x[rank] = x
all_x now contains the x produced by every GPU. None of these tensors carries a grad_fn except the one from the current GPU, because of the assignment all_x[rank] = x. The loss can then be computed from all_x.
In other words, by writing the current GPU's original tensor back into the corresponding rank index of all_x, the element all_x[rank] keeps its gradient path, so every GPU can backpropagate through its own contribution.
Template code:
import torch
import torch.distributed as dist
# Dummy code representing the forward pass for some batch of text on one replica.
embeddings = model(batch)
# Gather the embeddings from every replica.
embeddings_list = [torch.ones_like(embeddings) for _ in range(dist.get_world_size())]
dist.all_gather(embeddings_list, embeddings)
# Gathered tensors have no gradient, so we overwrite the gathered tensor for the current replica.
# with the embeddings produced on this replica, which do have gradients.
embeddings_list[dist.get_rank()] = embeddings
# Finally, concatenate the list of embeddings before computing a loss.
embeddings = torch.cat(embeddings_list)
# I didn't demonstrate how to generate the labels, this will be task-dependent.
loss = some_contrastive_loss(embeddings, labels)
Reference link:
https://github.com/KevinMusgrave/pytorch-metric-learning/issues/10
A gradient-preserving all_gather can also be implemented with any of the following three custom autograd functions (taken from SimCLR-style model code):
1.
class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all process, supporting backward propagation.
    '''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input)
                  for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out
How to use:
allgather = GatherLayer.apply
features_gather = allgather(features)  # gather the data from all GPUs together
Reference link:
https://i.steer.space/blog/2021/01/pytorch-dist-nccl-backend-allgather-stuck
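A single-process sanity check (gloo backend, world_size 1, and a file-based rendezvous are my assumptions for portability) shows that the output of GatherLayer stays connected to the autograd graph, unlike a plain dist.all_gather:

```python
import os
import tempfile

import torch
import torch.distributed as dist

class GatherLayer(torch.autograd.Function):
    '''Gather tensors from all process, supporting backward propagation.'''
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        output = [torch.zeros_like(input) for _ in range(dist.get_world_size())]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        input, = ctx.saved_tensors
        grad_out = torch.zeros_like(input)
        grad_out[:] = grads[dist.get_rank()]
        return grad_out

# Single-process group so the example runs on CPU without a launcher.
dist.init_process_group(
    backend="gloo",
    init_method=f"file://{os.path.join(tempfile.mkdtemp(), 'rdzv')}",
    rank=0, world_size=1,
)

x = torch.ones(3, requires_grad=True)
gathered = GatherLayer.apply(x)       # tuple with one tensor per rank
torch.cat(gathered).sum().backward()  # gradient flows back to x
print(x.grad)  # tensor([1., 1., 1.])
dist.destroy_process_group()
```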
2.
class SyncFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]
        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(gathered_tensor, tensor)
        gathered_tensor = torch.cat(gathered_tensor, 0)
        return gathered_tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        torch.distributed.all_reduce(grad_input, op=torch.distributed.ReduceOp.SUM, async_op=False)
        idx_from = torch.distributed.get_rank() * ctx.batch_size
        idx_to = (torch.distributed.get_rank() + 1) * ctx.batch_size
        return grad_input[idx_from:idx_to]
How to use:
allgather = SyncFunction.apply
features_gather = allgather(features)  # gather the data from all GPUs together
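To see why SyncFunction.backward slices grad_input[idx_from:idx_to]: after the all_reduce sum, every rank holds the full gradient of the concatenated output, and each rank keeps only the rows corresponding to its own input chunk. A small standalone illustration of the index arithmetic (the world_size of 2 and batch_size of 4 are made-up values):

```python
import torch

# Gradient of the concatenated output has world_size * batch_size rows.
world_size, batch_size = 2, 4
grad_input = torch.arange(world_size * batch_size)

chunks = []
for rank in range(world_size):
    idx_from = rank * batch_size
    idx_to = (rank + 1) * batch_size
    # Each rank would return exactly its own rows.
    chunks.append(grad_input[idx_from:idx_to].tolist())
print(chunks)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```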
3.
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    """An autograd function that performs allgather on a tensor."""
    @staticmethod
    def forward(ctx, tensor):
        output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        torch.distributed.all_gather(output, tensor)
        ctx.rank = dist.get_rank()
        ctx.batch_size = tensor.shape[0]
        return torch.cat(output, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # forward takes a single tensor input, so backward must return a
        # single gradient: the slice of grad_output for this rank's chunk.
        return grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)]
How to use:
allgather = AllGather.apply
features_gather = allgather(features)  # gather the data from all GPUs together
Reference link: https://github.com/ArrowLuo/CLIP4Clip
4. Related Links
- Official link: torch.distributed.all_gather()
- Pytorch - gradient backpropagation based on torch.distributed.all_gather
- PyTorch multi-process distributed training practice
- Pytorch's distributed.all_gather stuck troubleshooting under NCCL backend
- Basic concepts and issues involved in PyTorch distributed DPP
- PyTorch distributed training detailed tutorial scatter, gather & isend, irecv & all_reduce & DDP
- Graphical DistributedDataParallel (DDP) communication method: gather, all_gather, all_reduce, reduce, scatter