Torch distributed training

Introduction

torch.nn.DataParallel

torch.nn.DataParallel is a module in PyTorch for training neural networks in parallel on multiple GPUs. It replicates a single model onto each GPU, splits every input batch across the replicas, runs the same forward computation on each GPU, and then sums the resulting gradients back onto the default device before the model parameters are updated. In this way, training a neural network can be significantly accelerated.

Using torch.nn.DataParallel is easy: just wrap the model in torch.nn.DataParallel when defining it. For example:

import torch.nn as nn

model = nn.DataParallel(MyModel())

This replicates MyModel() across the available GPUs and runs the same operation on each of them in parallel, with each GPU processing its own slice of the input batch.
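For completeness, here is a minimal end-to-end sketch (not taken from the original example); the MyModel class below is just a placeholder network, and if no GPU is available DataParallel simply falls back to running on a single device:

import torch
import torch.nn as nn

# Placeholder network; substitute your own model here.
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# DataParallel splits each input batch across the visible GPUs, replicates the
# model on each of them, and gathers the outputs back on the default device.
model = nn.DataParallel(MyModel()).to(device)

x = torch.randn(32, 10, device=device)  # a batch of 32 samples
outputs = model(x)                       # forward pass runs on all GPUs in parallel
print(outputs.shape)                     # torch.Size([32, 2])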

It should be noted that if you are using PyTorch 1.6 or later, you do not have to rely on torch.nn.DataParallel: PyTorch ships higher-level distributed training modules such as torch.nn.parallel.DistributedDataParallel, which offer better performance and more flexible configuration and are a better fit for most distributed training scenarios.

torch.nn.parallel.DistributedDataParallel

torch.nn.parallel.DistributedDataParallel is a module in PyTorch for training neural networks in parallel in a distributed environment. Unlike torch.nn.DataParallel, it supports multi-process and multi-machine training, so a network can be trained across several computers at the same time, which can significantly speed up the training process.

Using torch.nn.parallel.DistributedDataParallel requires the following steps:

  1. Start a process group: in distributed training, a process group is used for communication between processes. It is started with the torch.distributed.init_process_group() function, which takes parameters such as the backend (e.g. 'gloo' or 'nccl'), the initialization method, the total number of processes (world_size) and the rank of the current process; a sketch of this call appears after this list.

  2. Load the dataset: in distributed training, each process reads only part of the dataset, so the data must be partitioned to guarantee that no sample is read twice or skipped. PyTorch's DistributedSampler handles this partitioning, and DataLoader is used to load the data as usual.

  3. Define the model: in distributed training, the model must be initialized consistently in every process. You define the same model in each process and then wrap it with the torch.nn.parallel.DistributedDataParallel wrapper provided by PyTorch (see the sketch after this list).

  4. Train the model: in distributed training, every process performs the forward pass, backward pass and parameter update in parallel. The usual backward() and step() calls provided by PyTorch handle backpropagation and the parameter update; DistributedDataParallel all-reduces the gradients across processes automatically during backward(), and torch.distributed.all_reduce() can additionally be used to sum other tensors (such as metrics) across processes.

  5. End training: in distributed training, the process group must be shut down cleanly, which is done with the torch.distributed.destroy_process_group() function.
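As a concrete illustration of steps 1 and 3, the sketch below assumes the script is launched with torchrun (which sets the RANK, WORLD_SIZE and LOCAL_RANK environment variables for every process it starts) and that each process drives exactly one GPU; the nn.Linear is only a stand-in for a real network:

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# Step 1: start the process group. With init_method='env://', init_process_group
# reads RANK and WORLD_SIZE from environment variables that torchrun has set.
dist.init_process_group(backend='nccl', init_method='env://')

# Step 3: build the model in every process, move it to this process's GPU, and
# wrap it with DistributedDataParallel so that gradients are all-reduced
# automatically during backward().
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
model = nn.Linear(10, 2).to(local_rank)  # stand-in for a real network
model = DistributedDataParallel(model, device_ids=[local_rank])

A script structured this way would typically be launched with something like torchrun --nproc_per_node=4 train.py, so that every process receives its own rank automatically instead of hard-coding it.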

It should be noted that using torch.nn.parallel.DistributedDataParallel requires some changes to the code, such as starting the process group, partitioning the dataset and wrapping the model, and issues such as data partitioning and gradient synchronization must be kept in mind. Using torch.nn.parallel.DistributedDataParallel therefore calls for some distributed programming knowledge and experience.

Example

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel

# Start the process group (each of the 4 processes runs this with its own rank; rank=0 is shown here)
dist.init_process_group(backend='gloo', init_method='file:///tmp/some_file', world_size=4, rank=0)

# Load the dataset; DistributedSampler partitions it across the processes
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
train_sampler = DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=2, sampler=train_sampler)

# Define the model and wrap it with DistributedDataParallel
model = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(64 * 16 * 16, 10)
)
model = DistributedDataParallel(model)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Train the model (set_epoch re-shuffles the sampler's partition each epoch)
for epoch in range(10):
    train_sampler.set_epoch(epoch)
    for i, (inputs, labels) in enumerate(train_loader):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# End training: shut down the process group
dist.destroy_process_group()

In this example, the CIFAR10 dataset is used, a simple convolutional neural network is defined, and the model is wrapped with torch.nn.parallel.DistributedDataParallel. DistributedSampler partitions the dataset across the processes and DataLoader loads each process's share. During training, backward() and step() perform backpropagation and the parameter update, while DistributedDataParallel all-reduces the gradients across processes automatically during backward(). Finally, torch.distributed.destroy_process_group() shuts down the process group.
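Note that the example hard-codes rank=0, so run on its own it would wait for the other three processes in the group. As a rough sketch (with the example's dataset, model and training loop folded into a train function), one common way to actually start all the processes on a single machine is torch.multiprocessing.spawn; launching with torchrun is the other usual option:

import torch.distributed as dist
import torch.multiprocessing as mp

def train(rank, world_size):
    # Each spawned process receives its own rank as the first argument.
    dist.init_process_group(backend='gloo', init_method='file:///tmp/some_file',
                            world_size=world_size, rank=rank)
    # ... build the sampler, model, optimizer and training loop from the example above ...
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = 4
    # Start world_size processes; each one calls train(rank, world_size).
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)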
