Pytorch fijo parámetros-modelo de preentrenamiento y ajuste fino

Al revisar muchos blogs y foros, generalmente los parámetros de congelación incluyen dos pasos:

  1. Establezca el atributo del parámetro en False, es decir, require_grad = False
  2. Cuando se define el optimizador, los parámetros que no actualizan el gradiente se filtran, que suele ser el caso
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

No entraré en detalles arriba, la mayor parte de Baidu es así.


Déjame hablar primero de mi tarea:

Tengo un modelo que consta de un codificador y un decodificador. Los parámetros del decodificador se fijan durante el pre-entrenamiento, y solo se entrenan los parámetros del codificador. Luego entrene todos los parámetros durante el ajuste fino.

problema:

De acuerdo con el método anterior, se informará un error de longitud inconsistente al recargar el modelo.

ValueError: el dictado de estado cargado contiene un grupo de parámetros que no coincide con el tamaño del grupo del optimizador

Después de depurar durante mucho tiempo, descubrí que el modelo que cargué solo guardaba los parámetros de la parte del codificador, pero el nuevo modelo son los parámetros del codificador y el decodificador. Por lo tanto, los parámetros entrenados previamente no se pueden cargar en el nuevo modelo.

Solución:

Solo establezca el atributo del parámetro en Verdadero / Falso, sin filtrar el parámetro en el optimizador, para que la longitud sea consistente.

Además, los parámetros fijos durante el proceso de preentrenamiento no se actualizan, y todos los parámetros se actualizan durante el ajuste fino, que solo cumple con nuestros requisitos.


Adjuntar mi proceso de ajuste:

  • Entrenamiento previo: solo modificar atributos, no filtrar parámetros
    for param in model.parameters():
        param.requires_grad = False
    for param in model.encoder.parameters():
        param.requires_grad = True

Genere los dos parámetros actualizados, puede encontrar que solo el codificador está actualizado y el decodificador no está actualizado.

  • afinar:
    for param in model.parameters():
        param.requires_grad = True

También envíe los parámetros actualizados dos veces, puede encontrar que los parámetros del decodificador también se han actualizado. ¡terminado!

Supongo que te gusta

Origin blog.csdn.net/Answer3664/article/details/104874243
Recomendado
Clasificación