ChatGLM Efficient Tuning源码解析 src/utils/peft_trainer.py

import os
import json
import time
import torch
from typing import Dict, Optional
from datetime import timedelta

from transformers import (
    Seq2SeqTrainer,
    TrainerCallback,
    TrainerControl,
    TrainerState,
    TrainingArguments
)

from transformers.trainer import TRAINING_ARGS_NAME
from transformers.modeling_utils import unwrap_model

from .config import FinetuningArguments

from .other import (
    get_logger,
    get_state_dict,
    load_trainable_params,
    load_valuehead_params,
    FINETUNING_ARGS_NAME,
    VALUE_HEAD_FILE_NAME
)


logger = get_logger(__name__)


class LogCallback(TrainerCallback):
    r"""
    TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
    The on_log function primarily collects process parameters during training, such as training loss, learning rate,
    and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
    time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
    purposes.
    """

    def __init__(self):
        self.start_time = time.time()

    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
        r"""
        Event called after logging the last logs.
        """
        if "loss" not in state.log_history[-1]:
            return
        cur_time = time.time()
        cur_steps = state.log_history[-1].get("step")
        elapsed_time = cur_time - self.start_time
        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
        remaining_steps = state.max_steps - cur_steps
        remaining_time = remaining_steps * avg_time_per_step
        log_dict = {
            "current_steps": cur_steps,
            "total_steps": state.max_steps,
            "loss": state.log_history[-1].get("loss", None),
            "reward": state.log_history[-1].get("reward", None),
            "learning_rate": state.log_history[-1].get("learning_rate", None),
            "epoch": state.log_history[-1].get("epoch", None),
            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
            "remaining_time": str(timedelta(seconds=int(remaining_time)))
        }
        os.makedirs(args.output_dir, exist_ok=True)
        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
            f.write(json.dumps(log_dict) + "\n")

以上是对这段代码的逐行解释。这段代码主要定义了一个名为LogCallback的类,该类继承自TrainerCallback。它重写了on_log方法,在每次训练过程打log时,它会计算一些关于训练过程的信息,然后将这些信息写入到一个JSONL(JSON Lines)格式的文件中,以便后续进行分析或可视化。

  1. from datetime import timedelta :这行代码是从datetime库中导入timedelta模块。 timedelta主要用于计算时间差,比如两个日期或者时间的差。

  2. from transformers import (Seq2SeqTrainer, TrainerCallback, TrainerControl, TrainerState, TrainingArguments):这行代码从transformers库中导入了几个模块,包括Seq2SeqTrainer(用于序列到序列模型的训练),TrainerCallback(用于定义训练过程中的回调函数),TrainerControl(用于控制训练过程),TrainerState(用于保存训练状态)和TrainingArguments(用于配置训练参数)。

  3. from transformers.trainer import TRAINING_ARGS_NAME:从transformers库的trainer模块中导入TRAINING_ARGS_NAME

  4. from transformers.modeling_utils import unwrap_model:从transformers库的modeling_utils模块中导入unwrap_model函数

  5. from .config import FinetuningArguments:从当前目录下的config模块中导入FinetuningArguments。

  6. from .other import (get_logger, get_state_dict, load_trainable_params, load_valuehead_params, FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME):从当前目录下的other模块中导入了一些函数和变量。

  7. logger = get_logger(__name__):创建一个日志记录器,__name__为模块名。

  8. class LogCallback(TrainerCallback)::定义了一个名为LogCallback的类,这个类继承了TrainerCallback类。

  9. def __init__(self)::定义LogCallback类的初始化函数。

  10. self.start_time = time.time():在初始化函数中,定义了一个成员变量start_time,记录了初始化的时间。

  11. def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None::定义了on_log方法,这个方法是在训练过程中每次打log时被调用。

  12. if "loss" not in state.log_history[-1]: return:如果最新的log中不包含"loss",则不执行下面的语句。

  13. cur_time = time.time():获取当前的时间。

  14. cur_steps = state.log_history[-1].get("step"):获取当前的步数。

  15. elapsed_time = cur_time - self.start_time:计算从开始到现在过去的时间。

  16. avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0计算每一步的平均时间

  17. remaining_steps = state.max_steps - cur_steps:计算剩余的步数。

  18. remaining_time = remaining_steps * avg_time_per_step:计算剩余的时间。

  19. log_dict = {...}:创建一个字典,用于记录训练的状态。

  20. os.makedirs(args.output_dir, exist_ok=True):创建输出目录,如果目录已存在,则不报错。

  21. with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f::在输出目录下创建一个名为"trainer_log.jsonl"的文件,以追加模式打开。

  22. f.write(json.dumps(log_dict) + "\n"):将log_dict转为json格式后写入到文件中,每条记录后面添加一个换行符。

  23. os.makedirs(args.output_dir, exist_ok=True):这行代码调用os库的makedirs函数来创建一个新目录。args.output_dir是新目录的路径,exist_ok=True意味着如果该目录已存在,那么不会引发任何错误,而是继续执行。

  24. with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f::这行代码打开一个文件用于追加写入。文件路径是由args.output_dir(输出目录)和"trainer_log.jsonl"(文件名)拼接而成的。用上下文管理器with打开文件可以保证在操作完成后文件会被正确关闭。

  25. f.write(json.dumps(log_dict) + "\n"):这行代码将log_dict字典转换为JSON格式字符串,然后写入文件。并在每一条日志后面添加一个换行符,以便将各个日志条目分隔开。

class PeftTrainer(Seq2SeqTrainer):
    r"""
    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
    """

    def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
        super().__init__(**kwargs)
        self.finetuning_args = finetuning_args
        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
            logger.warning("Previous log file in this folder will be deleted.")
            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))

    def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
        r"""
        Saves trainable parameters as model checkpoint.

        This function will only be executed at the process zero.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Saving model checkpoint to {output_dir}")
        model = unwrap_model(self.model)

        if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
            backbone_model = getattr(model, "pretrained_model")
            torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
        else:
            backbone_model = model

        if self.finetuning_args.finetuning_type == "lora":
            backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
        else: # freeze/full tuning
            backbone_model.save_pretrained(
                output_dir,
                state_dict=get_state_dict(backbone_model),
                safe_serialization=self.args.save_safetensors
            )
            if self.tokenizer is not None:
                self.tokenizer.save_pretrained(output_dir)

        with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
            f.write(self.args.to_json_string() + "\n")
        self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))

    def _load_best_model(self):
        r"""
        Loads trainable parameters from model checkpoint.

        Subclass and override to inject custom behavior. It should not be directly used by external scripts.
        """
        logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")

        model = unwrap_model(self.model)
        backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model

        if self.finetuning_args.finetuning_type == "lora":
            backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter"))
            if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
                model.v_head.load_state_dict({
                    "summary.weight": getattr(model, "reward_head_weight"),
                    "summary.bias": getattr(model, "reward_head_bias")
                })
        else: # freeze/full-tuning or p_tuning
            load_trainable_params(backbone_model, self.state.best_model_checkpoint)

这段代码定义了一个名为PeftTrainer的类,这个类继承自Seq2SeqTrainer,用于支持参数有效的模型检查点保存和加载。接下来我将逐行解释这个类中的代码:

以上是_save方法的解释,它主要负责将模型的参数、分词器以及训练和微调参数保存到文件中。

以上是save_model方法的解释,它主要负责将模型保存到指定的目录。

以上是load_best_model方法的解释,它主要负责加载最佳模型。

综上所述,PeftTrainer类提供了模型保存、加载、以及从检查点加载模型等功能。

  1. class PeftTrainer(Seq2SeqTrainer)::定义了一个新的类PeftTrainer,这个类继承自Seq2SeqTrainer

  2. def __init__(self, finetuning_args: FinetuningArguments, **kwargs)::定义了这个类的初始化函数,这个函数接收一个FinetuningArguments型的参数finetuning_args任意数量的其他关键字参数

  3. super().__init__(**kwargs):调用父类的初始化函数

  4. self.finetuning_args = finetuning_args:将传入的finetuning_args参数赋值给self.finetuning_args

  5. if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl"))::如果当前进程是主进程并且在输出目录下存在名为"trainer_log.jsonl"的文件。

  6. logger.warning("Previous log file in this folder will be deleted."):则在日志中记录一条警告信息。

  7. os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl")):并删除这个文件。

  8. def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None::定义了一个名为_save的方法,这个方法接收一个可选的字符串参数output_dir和一个可选的字典参数state_dict

  9. output_dir = output_dir if output_dir is not None else self.args.output_dir:如果output_dir参数为None,则将self.args.output_dir赋值给output_dir

  10. os.makedirs(output_dir, exist_ok=True):创建输出目录,如果目录已存在,则不报错。

  11. logger.info(f"Saving model checkpoint to {output_dir}"):在日志中记录一条信息,表示正在将模型检查点保存到指定的目录

  12. model = unwrap_model(self.model):调用unwrap_model函数获取原始模型。

  13. if hasattr(model, "pretrained_model")::如果模型有"pretrained_model"这个属性。

  14. backbone_model = getattr(model, "pretrained_model"):则获取这个属性并赋值给backbone_model

  15. torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME)):保存模型的v_head属性的状态字典到文件中。

  16. else::如果模型没有"pretrained_model"这个属性。

  17. backbone_model = model:则将原始模型赋值给backbone_model

  18. if self.finetuning_args.finetuning_type == "lora"::如果微调类型为"lora"。

  19. backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)):则将backbone_model的状态字典保存为预训练模型

  20. else::如果微调类型不为"lora"

  21. def _load_best_model(self)::定义了一个名为_load_best_model的方法,这个方法没有参数。

  22. logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric})."):在日志中记录一条信息,表示正在从指定的检查点加载最佳模型。

  23. model = unwrap_model(self.model):调用unwrap_model函数获取原始模型。

  24. backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model:如果模型有"pretrained_model"这个属性,则获取这个属性并赋值给backbone_model,否则将原始模型赋值给backbone_model

  25. if self.finetuning_args.finetuning_type == "lora"::如果微调类型为"lora"。

  26. backbone_model.load_adapter(self.state.best_model_checkpoint, getattr(backbone_model, "active_adapter")):则从最佳模型的检查点加载适配器。

  27. if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint)::如果模型有"v_head"这个属性并且可以从最佳模型的检查点加载v_head的参数。

  28. model.v_head.load_state_dict({"summary.weight": getattr(model, "reward_head_weight"), "summary.bias": getattr(model, "reward_head_bias")}):则从模型中获取reward_head_weightreward_head_bias属性,并将它们的值加载到v_head的状态字典中。

  29. else::如果微调类型不为"lora"。

  30. load_trainable_params(backbone_model, self.state.best_model_checkpoint):则从最佳模型的检查点加载可训练的参数。

  31. backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model), safe_serialization=self.args.save_safetensors):则以安全的方式将backbone_model的状态字典保存为预训练模型。

  32. if self.tokenizer is not None::如果有分词器。

  33. self.tokenizer.save_pretrained(output_dir):则将分词器保存为预训练模型。

  34. with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f::打开一个文件用于写入训练参数

  35. f.write(self.args.to_json_string() + "\n"):将训练参数转换为JSON格式的字符串并写入文件。

  36. self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME)):将微调参数保存为JSON格式的文件

    以上是_load_best_model方法的解释,它主要负责从检查点加载最佳模型的参数

    总的来说,PeftTrainer类是一个用于参数有效的模型微调的训练器

    接着,我们继续解析PeftTrainer类的代码:

  37. def _load_checkpoint(self, checkpoint_path: str, resume_training: bool = False) -> bool::定义了一个名为_load_checkpoint的方法,这个方法接收一个字符串参数checkpoint_path和一个布尔值参数resume_training

  38. logger.info(f"Loading model checkpoint from {checkpoint_path}"):在日志中记录一条信息,表示正在从指定的路径加载模型检查点。

  39. model = unwrap_model(self.model):调用unwrap_model函数获取原始模型。

  40. backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model:如果模型有"pretrained_model"这个属性,则获取这个属性并赋值给backbone_model,否则将原始模型赋值给backbone_model

  41. if self.finetuning_args.finetuning_type == "lora"::如果微调类型为"lora"。

  42. backbone_model.load_adapter(checkpoint_path, getattr(backbone_model, "active_adapter")):则从检查点路径加载适配器。

  43. if hasattr(model, "v_head") and load_valuehead_params(model, checkpoint_path)::如果模型有"v_head"这个属性并且可以从检查点路径加载v_head的参数。

  44. model.v_head.load_state_dict({"summary.weight": getattr(model, "reward_head_weight"), "summary.bias": getattr(model, "reward_head_bias")}):则从模型中获取reward_head_weightreward_head_bias属性,并将它们的值加载到v_head的状态字典中。

  45. else::如果微调类型不为"lora"。

  46. load_trainable_params(backbone_model, checkpoint_path):则从检查点路径加载可训练的参数。

  47. if resume_training::如果需要继续训练。

  48. self.state = TrainerState.load_from_json(os.path.join(checkpoint_path, "trainer_state.json")):则从检查点路径中的"trainer_state.json"文件加载训练状态

    以上是_load_checkpoint方法的解释,它主要负责从指定的路径加载模型检查点以及训练状态。

  49. def save_model(self, output_dir: Optional[str] = None) -> None::定义了一个名为save_model的方法,这个方法接收一个可选的字符串参数output_dir

  50. output_dir = output_dir if output_dir is not None else self.args.output_dir:如果output_dir参数为None,则将self.args.output_dir赋值给output_dir

  51. self._save(output_dir):调用_save方法将模型保存到指定的目录。

  52. def load_best_model(self) -> None::定义了一个名为load_best_model的方法,这个方法没有参数。

  53. self._load_best_model():调用_load_best_model方法加载最佳模型。

接着,我们继续解析PeftTrainer类的代码:

def load_checkpoint(self, checkpoint_path: str, resume_training: bool = False) -> bool::定义了一个名为load_checkpoint的方法,这个方法接收一个字符串参数checkpoint_path和一个布尔值参数resume_training

self._load_checkpoint(checkpoint_path, resume_training):调用_load_checkpoint方法从指定路径加载检查点。

以上是load_checkpoint方法的解释,它主要负责从指定路径加载检查点。

以上是load_latest_checkpoint方法的解释,它主要负责从最新的检查点加载模型。

至此,PeftTrainer类的全部代码已经解析完成。这个类主要提供了模型训练、模型保存、从检查点加载模型以及加载最佳模型等功能。

  1. def load_latest_checkpoint(self, resume_training: bool = False) -> bool::定义了一个名为load_latest_checkpoint的方法,这个方法接收一个布尔值参数resume_training

  2. checkpoints = list(sorted(os.listdir(self.args.output_dir))):列出self.args.output_dir目录下的所有文件,并按照文件名进行排序。

  3. if not checkpoints::如果没有检查点。

  4. return False:则返回False

  5. latest_checkpoint = checkpoints[-1]:否则,取出最后一个检查点赋值给latest_checkpoint

  6. checkpoint_path = os.path.join(self.args.output_dir, latest_checkpoint):将最后一个检查点的路径赋值给checkpoint_path

  7. return self.load_checkpoint(checkpoint_path, resume_training):调用load_checkpoint方法从最后一个检查点的路径加载检查点,并返回加载结果。

QA

logger = get_logger(__name__) ,这个里面__name__是什么

__name__ 是 Python 中的一个预定义(内建)的变量,它代表当前模块的名字。在 Python 脚本中,如果该脚本被直接运行,那么 __name__ 的值会被设置为 '__main__';如果该脚本是被作为模块导入到其他脚本中,那么 __name__ 的值则会被设置为该模块的名字。

在你提到的代码 logger = get_logger(__name__) 中,get_logger 函数可能是用来创建或获取一个日志记录器(logger),而 __name__ 的值则用来设置该日志记录器的名字,这样在输出日志时可以知道是哪个模块产生的日志。这在调试、测试或者监控程序运行状态时非常有用。

 FINETUNING_ARGS_NAMEVALUE_HEAD_FILE_NAME 代表什么

FINETUNING_ARGS_NAMEVALUE_HEAD_FILE_NAME 可能是在代码中定义的一些常量。这些常量代表了一些特定的字符串,它们通常在代码的多个地方使用,这样如果需要改变这些字符串,只需要改变一次定义就可以了,而不需要在整个代码中找到所有的使用地方进行修改。

  1. FINETUNING_ARGS_NAME:这个变量可能是用来定义微调参数的文件名。在微调模型的过程中,可能会用到一些特定的参数,这些参数可能会被保存在一个文件中以方便之后的使用或者记录。

  2. VALUE_HEAD_FILE_NAME:这个变量可能是用来定义某个特定模型部分(如神经网络的一个头部,即网络的一部分)的文件名。在训练神经网络的过程中,可能会保存网络的某个部分的参数或者状态,这个文件名可能就是用来保存这些信息的文件的名字。

注意,这里的解释都是基于对变量名的猜测。实际的用途可能会根据代码的实际内容有所不同。如果你能提供更多的代码上下文,可能会得到更准确的解释。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/131489790
今日推荐