Pytorch multi-GPU DataParallel y la acumulación de gradiente resuelven el problema del desequilibrio de la memoria y la memoria insuficiente

  Al realizar experimentos de clasificación de imágenes recientemente, utilicé la función DataParallel de pytorch para ejecutar el programa en paralelo en 4 GPU. Cuando el tamaño del lote es 16, se informará el siguiente error: RuntimeError: CUDA sin memoria. Se intentó asignar 858,00 MiB (GPU
  3; 10,92 GiB de capacidad total; 10,10 GiB ya asignados; 150,69 MiB libres; 10,13 GiB reservados en total por PyTorch)

  El experimento encontró que cada GPU puede ejecutar hasta 2 datos, pero quiero establecer el tamaño de lote = 16. Consulte https://zhuanlan.zhihu.com/p/86441879 para obtener información sobre la función BalancedDataParallel escrita oficialmente por transformador. XL para resolver el problema: Problema de desequilibrio en el uso de la memoria de video de DataParallel (consulte el código de referencia al final).
  Para comprender el uso de la función BalancedDataParallel, primero aclaremos algunas cuestiones.
1. ¿Cómo funciona la función DataParallel?
  Primero cargue el modelo en la GPU maestra, luego copie el modelo en cada GPU esclava designada y luego divida los datos de entrada de acuerdo con la dimensión del lote. Específicamente, el número de lotes de datos asignados a cada GPU son los datos de entrada totales. se divide por el número especificado de GPU. Cada GPU realizará de forma independiente cálculos directos en sus propios datos de entrada y luego transferirá los resultados del cálculo a la GPU principal para completar el cálculo del gradiente y la actualización de parámetros. Finalmente, los parámetros actualizados se copiarán a la GPU esclava, completando así un cálculo iterativo. . Consulte https://blog.csdn.net/zhjm07054115/article/details/104799661 Cuando gpu = 2, lote_size = 30, podemos ver claramente en la figura siguiente que los primeros 15 datos se asignarán en las dos gpu respectivamente. Realice cálculos directos y luego resuma los resultados para el cálculo del gradiente y la actualización de parámetros.
  Podemos ver que el cálculo de retropropagación y la actualización de parámetros se realizan completamente en la GPU principal, lo que causará el problema del uso desequilibrado de la memoria de video.
Insertar descripción de la imagen aquí
2.
  Referencia de acumulación de gradientehttps://blog.csdn.net/wuzhongqiang/article/details/102572324 realizó un experimento de acumulación de gradiente de retropropagación y descubrió que cuando Pytorch realiza la retropropagación, acumula el último gradiente de forma predeterminada. Si no desea la última vez Si el El gradiente afecta su cálculo de gradiente esta vez, debe borrarlo manualmente.

  Después de comprender la función DataParallel y la acumulación de gradiente, podemos resolver el problema del uso desequilibrado de la memoria de video y cómo aumentar el lote de entrenamiento con memoria de video fija.
  En primer lugar, una breve introducción al uso de BalancedDataParallel [La siguiente imagen está tomada de https://github.com/Link-Li/Balanced-DataParallel ]
Insertar descripción de la imagen aquí
  Una breve explicación: cuando necesitamos ejecutar el programa en paralelo en 3 GPU, cada GPU puede procesar como máximo 3 datos a la vez, la asignación es [3,3,3], luego 3 GPU pueden procesar hasta 9 datos al mismo tiempo, es decir, el tamaño de lote máximo puede debe establecerse en 9, porque la propagación hacia atrás también se realiza en la GPU principal, por lo que aquí configuramos la GPU principal que procesa 2 datos, la distribución es [2,3,3], tamaño de lote = 8.
  En este momento, si queremos aumentar el tamaño del lote para que el tamaño del lote = 16, entonces la distribución debe ser [4, 6, 6], pero sabemos que cada GPU puede procesar hasta 3 datos, por lo que la acumulación de gradiente Aquí se usa el método, es decir, el acc_grad en la figura anterior, los parámetros acc_grad indican en cuántas partes se divide el tamaño del lote y se envía a la red. Cuando acc_grad = 2, significa que primero dividiremos los 16 datos en 2 partes , Cada parte tiene 8 datos, y cada vez que se ingresan 8 datos, las tres GPU se usan para el entrenamiento en paralelo y los resultados del cálculo directo se colocan en la GPU principal para la propagación hacia atrás. Dado que el gradiente se puede acumular, los parámetros se actualizan después de dos ciclos. Hacerlo no sólo puede aliviar el problema del desequilibrio de la memoria de video sino también resolver el problema de la memoria de video insuficiente.
  El siguiente es el código completo que modifiqué según https://blog.csdn.net/zhjm07054115/article/details/104799661 y agregué BalancedDataParallel:

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from data_parallel_balance import BalancedDataParallel

# Dataset
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)
        self.target=np.random.randint(3,size=length)
    def __getitem__(self, index):
        label=torch.tensor(self.target[index])
        return self.data[index],label
    def __len__(self):
        return self.len
        
# model        
class Model(nn.Module):
    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input):
        output = self.fc(input)
        print("\tIn Model: input size", input.size(),
              "output size", output.size())
        return output
        
# trian
def train(rand_loader,model,optimizer,criterion):
    train_loss=0
   
    # train
    model.train()
    optimizer.zero_grad()
    for image,target in rand_loader:
        print('image:',image.shape)
        if batch_chunk > 0:
            image_chunks = torch.chunk(image, batch_chunk, 0)
            target_chunks = torch.chunk(target, batch_chunk, 0)
            
            for i in range(len(image_chunks)):
                print('image_chunks:',i)
                
                img=image_chunks[i].to(device)
                lab=target_chunks[i].to(device)
                out=model(img)
                
                print("Chunks_Outputs: input size", img.size(),
                  "output_size", out.size())
                
                loss=criterion(out,lab)
                # print('{} chunk,loss:{}.'.format(i,loss))
                train_loss+=loss.item()
                
                loss = loss.float().mean().type_as(loss) / len(image_chunks)
                loss.backward()
                      
        else:
            image = image.to(device)
            target=target.to(device)
            
            output = model(image)
            loss=criterion(output,target)
            train_loss=loss.item()
            
            print("Outside: input size", image.size(),
                  "output_size", output.size())

    optimizer.step()  
    
    return train_loss
    

if __name__=="__main__":

    input_size = 5
    output_size = 3
    
    batch_size = 32
    data_size = 70
    
    batch_chunk=2
    gpu0_bsz=8
    
    epochs=2
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    # data
    rand_loader = DataLoader(dataset=RandomDataset(input_size, data_size),
                             batch_size=batch_size, shuffle=True)
                             
    # model                 
    model = Model(input_size, output_size)
    
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        
        if gpu0_bsz >= 0:   
            model = BalancedDataParallel(gpu0_bsz // batch_chunk, model, dim=0)
        else:
            model = nn.DataParallel(model)
         
    model.to(device)
    
    # optimizer
    optimizer= torch.optim.SGD(model.parameters(), lr = 0.01, momentum = 0.9)
    
    # loss
    criterion=nn.CrossEntropyLoss()
    
    for epoch in range(epochs):
        print('Epoch:',epoch)
        train(rand_loader,model,optimizer,criterion)

Referencia: código oficial
de pytorch multi-gpu de entrenamiento paralelo
de transformador-XL
BalancedDataParallel Código de referencia
PyTorch-4 nn.DataParallel
Detalles del paralelismo de datos en la retropropagación de Pytorch: acumulación predeterminada al calcular gradientes

¡Todos son bienvenidos a dejar comentarios y críticas!

Supongo que te gusta

Origin blog.csdn.net/qq_44846512/article/details/115207166
Recomendado
Clasificación