Modelo de treinamento multi-GPU PyTorch - um método para usar uma única GPU ou CPU para inferência

1 Descrição do problema

PyTorch fornece métodos de treinamento de rede multi-GPU muito convenientes: DataParallele DistributedDataParallel. Quando se trata de alguns modelos complexos, várias GPUs são basicamente usadas para treinar e salvar o modelo em paralelo. No entanto, no estágio de inferência, apenas uma única GPU ou CPU é frequentemente usada para execução. Neste momento, como carregar os pesos do modelo salvos em um ambiente multi-GPU para o modelo em um único ambiente de execução de GPU/CPU tornou-se uma questão importante.

Se você ignorar as questões ambientais e carregar diretamente, frequentemente ocorrerão dois tipos de problemas:

1 Ocorreu um erro: IndexError: índice da lista fora do intervalo

A razão para este erro é que os parâmetros do modelo existente são obtidos e salvos em várias GPUs, portanto, serão salvos na GPU correspondente por padrão durante a leitura. No entanto, há apenas uma GPU no ambiente de inferência atual, portanto, esses parâmetros que estavam originalmente em outras GPUs são Os parâmetros na GPU não conseguem encontrar o número da GPU para onde deveriam ir e ocorre um erro de overflow.A essência é que o número da GPU estourou.

2 Ocorreu um erro: Chave(s) ausente(s) em state_dict:

A razão para este erro é que devido aos diferentes ambientes de treinamento e inferência do modelo, alguns parâmetros são perdidos, então um erro é relatado. Algumas soluções atualmente disponíveis online são ignorar esses parâmetros ausentes, como usar o comando: model.load_state_dict(torch.load('model.pth'), strict=False)
para importar o modelo com sucesso. Este comando permite que o programa importe parâmetros do modelo sem relatar erros e aparentemente com sucesso. strict=FalseMas, na verdade, o significado deste comando é ignorar os parâmetros ausentes , definindo-os ao importar os parâmetros do modelo.Ou seja, os pesos do modelo onde os parâmetros estão faltando ainda estão no estado aleatório inicializado, o que equivale a nenhum treinamento e aprendizado, então não há raciocínio.com verificação! ! !

2. Método de salvamento de modelo

Não importa qual método seja usado para inferência, durante o treinamento, certifique-se de que o programa salve o modelo desta forma:

torch.save(model.state_dict(), "model.pth")

3 Carregando o modelo em uma única GPU

Carregue os arquivos de peso treinados em várias GPUs em uma única GPU:

# 1 加载模型
model = Model()
# 2 指定运行设备,这里为单块GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 3 将模型用DataParallel方法封装一次
model = torch.nn.DataParallel(model)
# 4 将模型读入到GPU设备上
model = model_E2E.to(device)
# 5 加载权重文件
model.load_state_dict(torch.load(weight_path, map_location=device))

Através do programa acima, os arquivos de peso treinados em múltiplas GPUs podem ser carregados no modelo em um único ambiente de GPU. Há dois pontos a serem observados aqui:

  • Durante o treinamento multi-GPU, o modelo usa o método DataParallelou . Essas duas ferramentas de paralelização modificarão a estrutura do modelo e encapsularão o modelo em um novo módulo, geralmente denominado: . Portanto, o modelo salvo no arquivo de pesos é uma estrutura encapsulada. Para poder carregar todos os parâmetros, o modelo de inferência precisa ser estruturalmente consistente com o modelo original de treinamento multi-GPU até a etapa 3.
    DistributedDataParallelmoduleDataParallel

  • Os parâmetros foram usados ​​ao carregar os parâmetros do modelo na etapa 5 map_location. Este parâmetro informa ao PyTorch em qual dispositivo o tensor deve ser colocado ao carregar o modelo. configuração map_location=device, qualquer dispositivo no qual o modelo foi originalmente treinado será agora colocado no dispositivo especificado device='cuda:0'.

4 Carregue o modelo na CPU

Carregue o modelo na CPU:

from collections import OrderedDict

# 1 加载模型
model = Model()
# 2 指定设备CPU
device = "cpu"
# 3 读取权重文件
state_dict = torch.load(weight_path, map_location=device)
# 4 剥除权重文件中的module层
if next(iter(state_dict)).startswith("module."):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    state_dict = new_state_dict
# 5 加载权重文件
model.load_state_dict(state_dict)
# 6 将模型载入到CPU
model = model.to(device)

A lógica de carregar o modelo na CPU é semelhante à da GPU, o núcleo é porque o modelo no arquivo de peso original é encapsulado module.Model, então esse shell precisa ser removido e, finalmente, o modelo é lido e carregado na CPU .

5 Resumo

Em tarefas de aprendizagem profunda, é muito comum que os ambientes de treinamento e inferência sejam diferentes. É muito importante ler corretamente o arquivo de peso da rede nos diferentes ambientes. Na operação real, você deve garantir que o arquivo de peso correto seja lido. Esta é a premissa mais básica para inferência! É melhor fazer alguns experimentos comparativos antes da inferência (por exemplo: selecionar uma parte dos dados, aplicar programas existentes para treinamento e inferência e comparar os efeitos dos dois) para garantir que os pesos corretos foram lidos.

Guess you like

Origin blog.csdn.net/qq_44949041/article/details/132734466