Compute the number of trainable parameters in the model: print_trainable_parameters [see LoRA]

def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        # numel() is the number of scalar elements in this parameter tensor
        all_param += param.numel()
        # Only unfrozen parameters (e.g. the LoRA adapter weights) receive gradients
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} "
        f"|| trainable%: {100 * trainable_params / all_param}"
    )
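The counting logic only relies on `named_parameters()` yielding `(name, param)` pairs where each `param` has `numel()` and `requires_grad`. The sketch below splits it into a reusable helper that returns the counts instead of printing them, and exercises it without PyTorch. `count_parameters`, `trainable_percent` and `FakeParam` are hypothetical names introduced for illustration; `FakeParam` merely mimics the two attributes of `torch.nn.Parameter` that the function uses. The example layer is a frozen 4096x4096 weight plus a rank-8 LoRA pair (A: 8x4096, B: 4096x8), an assumed shape, not taken from the original post.

```python
from dataclasses import dataclass
from math import prod


def count_parameters(named_parameters):
    """Return (trainable, total) element counts from (name, param) pairs.

    Accepts model.named_parameters() from PyTorch, or any iterable of
    (name, obj) pairs where obj exposes numel() and requires_grad.
    """
    trainable = total = 0
    for _, param in named_parameters:
        n = param.numel()
        total += n
        if param.requires_grad:
            trainable += n
    return trainable, total


def trainable_percent(trainable, total):
    # Guard against an empty model to avoid ZeroDivisionError.
    return 100 * trainable / total if total else 0.0


@dataclass
class FakeParam:
    """Hypothetical stand-in for torch.nn.Parameter (numel/requires_grad only)."""
    shape: tuple
    requires_grad: bool

    def numel(self):
        return prod(self.shape)


# Assumed example: frozen base weight plus a trainable rank-8 LoRA pair.
params = [
    ("base.weight", FakeParam((4096, 4096), requires_grad=False)),
    ("lora_A.weight", FakeParam((8, 4096), requires_grad=True)),
    ("lora_B.weight", FakeParam((4096, 8), requires_grad=True)),
]
trainable, total = count_parameters(params)
print(trainable, total, trainable_percent(trainable, total))
```

Returning the counts rather than printing them makes the helper easy to assert on, for example to check that freezing the base model actually worked before starting training.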

Run it:

print_trainable_parameters(model)

The output is as follows:

trainable params: 8388608 || all params: 6666862592 || trainable%: 0.12582542214183376
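These numbers can be checked by hand. LoRA replaces the update of a frozen d_out x d_in weight with a product B·A of rank r, so each adapted matrix adds r·(d_in + d_out) trainable parameters. Because the adapter weights are registered as model parameters, "all params" already includes them, so the frozen base accounts for all params minus trainable params. The sketch below recomputes the printed percentage and shows one hypothetical configuration (32 layers, two adapted 4096x4096 projections per layer, r = 16) that yields exactly 8388608; the original post does not state which model or LoRA settings produced this output, so treat that configuration as an assumption. `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(num_layers, adapted_shapes, r):
    """Extra parameters LoRA adds: r * (d_in + d_out) per adapted matrix, per layer.

    adapted_shapes is a list of (d_in, d_out) pairs for the matrices adapted
    in every layer (hypothetical helper; bias terms are not counted).
    """
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in adapted_shapes)
    return num_layers * per_layer


# Values from the printed output above.
trainable = 8388608
total = 6666862592

pct = 100 * trainable / total   # reproduces the printed trainable%
frozen_base = total - trainable # parameters of the frozen base model

# Hypothetical configuration consistent with the trainable count (not from the source).
guess = lora_param_count(32, [(4096, 4096), (4096, 4096)], r=16)

print(pct, frozen_base, guess)
```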




Introductory Guide to Deep Learning in 2023 (12) - PEFT and LoRA_Jtag Agent's Blog - CSDN Blog


Origin blog.csdn.net/u013250861/article/details/131218289