Versão pytorch do mecanismo de carregamento do modelo pré-treinado

Ao fazer a detecção de alvos bidimensionais, modificaremos algumas partes da rede neural, como adicionar CBAM ou modificar FPN e assim por diante. No entanto, quando a rede modificada é treinada, o processo de carregamento dos pesos de pré-treinamento não relatará um erro, e mesmo o desempenho não aumentará, mas diminuirá após a modificação da rede. Os pontos de conhecimento contidos neste estão resumidos neste artigo.

O código para importar pesos pré-treinados na rede em pytorch é muito simples:

net = model()
net.to(device)
net.load_state_dict(torch.load('params.pth'))

dentro:

  • torch.load('params.pth') apenas carrega os parâmetros do modelo e não coloca os parâmetros na rede . Aqui, os parâmetros são carregados na memória na forma de pares chave-valor.
  • O método load_state_dict é colocar o dicionário de parâmetros na rede.

Quando tivermos feito modificações na rede, só precisamos fazer o seguinte para carregar os pesos pré-treinados anteriores na nova rede. No mmdetection ou em algum código-fonte aberto de Daniel, os pesos de pré-treinamento ainda podem ser carregados após modificar a rede pelos seguintes motivos

Acho que você gosta

Origin blog.csdn.net/qq_42308217/article/details/123140481
Recomendado
Clasificación