Al revisar muchos blogs y foros, generalmente los parámetros de congelación incluyen dos pasos:
- Establezca el atributo del parámetro en False, es decir, require_grad = False
- Cuando se define el optimizador, los parámetros que no actualizan el gradiente se filtran, que suele ser el caso
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
No entraré en detalles arriba, la mayor parte de Baidu es así.
Déjame hablar primero de mi tarea:
Tengo un modelo que consta de un codificador y un decodificador. Los parámetros del decodificador se fijan durante el pre-entrenamiento, y solo se entrenan los parámetros del codificador. Luego entrene todos los parámetros durante el ajuste fino.
problema:
De acuerdo con el método anterior, se informará un error de longitud inconsistente al recargar el modelo.
ValueError: el dictado de estado cargado contiene un grupo de parámetros que no coincide con el tamaño del grupo del optimizador
Después de depurar durante mucho tiempo, descubrí que el modelo que cargué solo guardaba los parámetros de la parte del codificador, pero el nuevo modelo son los parámetros del codificador y el decodificador. Por lo tanto, los parámetros entrenados previamente no se pueden cargar en el nuevo modelo.
Solución:
Solo establezca el atributo del parámetro en Verdadero / Falso, sin filtrar el parámetro en el optimizador, para que la longitud sea consistente.
Además, los parámetros fijos durante el proceso de preentrenamiento no se actualizan, y todos los parámetros se actualizan durante el ajuste fino, que solo cumple con nuestros requisitos.
Adjuntar mi proceso de ajuste:
- Entrenamiento previo: solo modificar atributos, no filtrar parámetros
for param in model.parameters():
param.requires_grad = False
for param in model.encoder.parameters():
param.requires_grad = True
Genere los dos parámetros actualizados, puede encontrar que solo el codificador está actualizado y el decodificador no está actualizado.
- afinar:
for param in model.parameters():
param.requires_grad = True
También envíe los parámetros actualizados dos veces, puede encontrar que los parámetros del decodificador también se han actualizado. ¡terminado!