ChatGLM Efficient Tuning源码解析train_sft.py

def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = DataCollatorForChatGLM(
        tokenizer=tokenizer,
        model=model,
        ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate),
        use_v2=model_args.use_v2
    )

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
                training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.eval_num_beams if \
                data_args.eval_num_beams is not None else training_args.generation_num_beams

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = Seq2SeqTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **trainer_kwargs
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_new_tokens": data_args.max_target_length + 1,
        "temperature": 0.95,
        "logits_processor": get_logits_processor()
    }

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results, tokenizer)

这是一个复杂的深度学习训练流程,我会逐行解释这个过程。

  1. def main(): 定义了一个主函数 main

  2. model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft") 这行从 prepare_args 函数获取四个参数:模型参数(model_args),数据参数(data_args),训练参数(training_args),微调参数(finetuning_args)。其中 stage="sft" 是传递给 prepare_args 的参数。

  3. dataset = prepare_data(model_args, data_args) 根据上述的模型和数据参数准备数据集。

  4. model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft") 这行加载了一个预训练的模型和对应的 tokenizer。

  5. dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft") 对加载的数据集进行预处理,这包括使用 tokenizer 对文本数据进行编码等操作。

  6. data_collator = DataCollatorForChatGLM(...) 创建一个 DataCollator,用于在训练过程中处理/对齐批处理数据。

  7. 下面两行覆盖了 Seq2SeqTrainer 的一些解码参数。

  8. 接下来的 if 代码块会根据 training_args.do_train 决定是否分割数据集,以及如何分割。如果进行训练,那么会根据设定的比例分割训练和验证数据集。如果只进行评估或预测,那么全集将作为评估数据集。

  9. trainer = Seq2SeqTrainerForChatGLM(...) 通过上述的参数来创建一个 Seq2SeqTrainer。

  10. gen_kwargs = {...} 设定了一些用于 model.generate 的关键字参数,这些参数在生成文本时非常重要。

  11. 接下来的 if 代码块中,如果设定了进行训练 (training_args.do_train),那么模型就会被训练,训练的结果会被保存和记录,训练结束后模型也会被保存。

  12. 下一个 if 代码块是针对模型评估的。如果设定了进行评估 (training_args.do_eval),那么就会对模型进行评估,评估的结果会被保存和记录。

  13. 最后一个 if 代码块是针对预测的。如果设定了进行预测 (training_args.do_predict),那么模型就会进行预测,预测的结果会被保存和记录。

总的来说,这是一个典型的深度学习模型的训练,评估和预测的流程。在这个过程中,数据会被加载,预处理,然后用于训练或者评估模型。这个过程中的各种参数都可以通过 prepare_args数来设定。

ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate) 这行代码是在设定 DataCollatorForChatGLM 的参数。

在这个参数中,ignore_pad_token_for_loss 控制了是否在计算损失函数时忽略填充(padding)的标记。

在自然语言处理中,我们经常需要将不同长度的序列填充到相同的长度以便能在一起处理。我们通过添加一些特殊的“填充”标记(pad tokens)来实现这一点。但是在计算模型的损失时,我们通常不希望这些填充标记影响损失的计算,因此我们会忽略它们。

data_args.ignore_pad_token_for_loss 是一个布尔值,如果为 True,那么在计算损失时会忽略填充标记。training_args.predict_with_generate 也是一个布尔值,它表示是否在预测阶段使用生成(generate)方法。如果 predict_with_generateTrue,那么我们会在预测阶段生成新的序列,这时候我们不希望忽略填充标记

因此,data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate 的结果将决定在计算损失时是否忽略填充标记。当我们需要忽略填充标记,并且不是在预测阶段使用生成方法时,我们才会在计算损失时忽略填充标记。

# Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_new_tokens": data_args.max_target_length + 1,
        "temperature": 0.95,
        "logits_processor": get_logits_processor()
    }

这里面的max_new_tokens 为什么是.max_target_length+1?

在代码中,max_new_tokens 参数被设置为 data_args.max_target_length + 1。这是因为在许多自然语言处理(NLP)的任务中,包括序列生成任务,我们经常会在每个生成的序列的末尾添加一个特殊的结束符号(<eos> 或者 <pad>),来表示该序列的结束。

因此,max_new_tokens 实际上表示的是生成的最大序列长度,包括了这个特殊的结束符号。所以我们通常会将 max_new_tokens 设置为 max_target_length + 1,其中 max_target_length 表示我们期望生成的序列的最大长度(不包括结束符号),+1 则是为了留出结束符号的位置。

但是具体的设置可能会依赖于你使用的具体模型和任务,以及你的 tokenizer 如何处理这个结束符号。在某些情况下,你可能需要将 max_new_tokens 设置为 max_target_length,而不是 max_target_length + 1

def _mp_fn(index):是什么意思 xla_pawn是Pytorch用来支持google TPU的扩展:

def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

这段代码定义了一个函数 _mp_fn,它接受一个参数 index(这里并未使用),然后调用 main 函数。

这种模式在使用 xla_spawn 时会用到。xla_spawn 是 PyTorch XLA(PyTorch 的扩展,用于支持使用 Google 的 TPU)的一个脚本,它允许用户在多个 TPU 核上运行同样的代码。xla_spawn 使用 multiprocessing (多进程) 来在每个 TPU 核上启动一个 Python 进程。

当你在 xla_spawn 下运行代码时,你需要一个函数(在这里是 _mp_fn),它接受一个索引参数(在这里是 index),并且在函数内部调用你的主程序(在这里是 main 函数)。index 参数代表的是当前的进程索引(或者说,当前的 TPU 核的索引)。你的主程序在每个 TPU 核上都会运行。

在这个具体的例子中,_mp_fn 并没有使用 index 参数。这意味着 main 函数在所有 TPU 核上的行为都是一样的。在一些更复杂的用例中,你可能会根据 index 参数来改变每个进程的行为,例如,你可能会让不同的进程处理数据的不同部分。

这段代码是使用一个名为Seq2SeqTrainerForChatGLM的训练器进行序列到序列模型的训练、评估和预测的过程。下面我会逐行解释这段代码:

  1. trainer = Seq2SeqTrainerForChatGLM(...):这里初始化一个Seq2SeqTrainerForChatGLM对象。这是一个专门为了处理序列到序列任务(比如对话系统,翻译等)的训练器。参数包括:微调参数(finetuning_args)、模型(model)、训练参数(training_args)、分词器(tokenizer)、数据整理者(data_collator)、回调函数(callbacks)、计算评价指标的方法(compute_metrics),以及一些额外的训练器参数(trainer_kwargs)

  2. gen_kwargs = {...}:定义一个字典,这个字典包含了一些关键词参数,用于调控模型生成文本时的行为,比如"do_sample"决定是否采样,"top_p"决定在生成时采样的集中程度,"max_new_tokens"是生成的最大长度,"temperature"用于调整采样过程中的随机性等。

  3. if training_args.do_train::这个判断句用于确定是否进行训练阶段。如果需要训练,则执行下面的代码。

  4. train_result = trainer.train():使用之前定义的训练器开始训练模型,返回训练结果。

  5. trainer.log_metrics(...)trainer.save_metrics(...):这两行代码用于记录训练过程中的各种度量(例如loss)并保存。

  6. trainer.save_state()trainer.save_model():这两行代码保存训练器的状态和模型。

  7. if trainer.is_world_process_zero() and model_args.plot_loss::如果是在多进程环境中的主进程,并且设置了需要绘制loss图,那么就绘制loss变化的图像。

  8. if training_args.do_eval::这个判断句用于确定是否进行评估阶段。如果需要进行评估,则执行下面的代码。

  9. metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs):使用之前定义的训练器开始评估模型,返回评估的度量。

  10. if training_args.predict_with_generate::如果启用了预测生成,那么"eval_loss"将是错误的,需要从度量中移除

  11. trainer.log_metrics("eval", metrics)trainer.save_metrics("eval", metrics):这两行代码用于记录评估过程中的各种度量并保存。

  12. if training_args.do_predict::这个判断句用于确定是否进行预测阶段。如果需要进行预测,则执行下面的代码。

  13. predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs):使用之前定义的训练器开始预测,返回预测的结果。

  14. if training_args.predict_with_generate::如果启用了预测生成,那么"predict_loss"将是错误的,需要从度量中移除。

  15. trainer.log_metrics("predict", predict_results.metrics)trainer.save_metrics("predict", predict_results.metrics):这两行代码用于记录预测过程中的各种度量并保存。

  16. trainer.save_predictions(predict_results, tokenizer):保存预测的结果。

以上就是对这段代码的逐行解释。

  • trainer.train():这个函数实现了模型的训练过程。具体来说,它会进行一个训练循环,对每一个batch的数据进行前向传播,计算损失,然后反向传播优化模型参数。

  • trainer.evaluate():这个函数实现了模型的评估过程。它会在给定的验证集上评估模型的性能,并返回一些度量结果(如准确率,F1分数等)。

  • trainer.predict():这个函数实现了模型的预测过程。它会在给定的测试集上预测模型的输出,然后可能会根据具体任务对这些输出进行进一步的处理。

这段代码还展示了一些使用训练器对象的常用模式,如保存和加载模型、记录和保存度量结果、以及处理并行训练和预测。总的来说,这是一个非常通用的模型训练、评估和预测的框架,可以用于许多不同的序列到序列的任务。

training_args.generation_max_length

  1. training_args.generation_max_length:这一行代码设置生成文本的最大长度。如果training_args.generation_max_length参数没有被设定(即它的值是None),则会用data_args.max_target_length作为默认值。否则,就会直接使用training_args.generation_max_length。所以这个操作实质上是在给generation_max_length提供一个默认值。

  2. training_args.generation_num_beams:这一行代码设置束搜索(beam search)的束宽。如果data_args.eval_num_beams参数没有被设定(即它的值是None),则会用training_args.generation_num_beams作为默认值。否则,就会直接使用data_args.eval_num_beams。所以这个操作实质上是在给generation_num_beams提供一个默认值。

这两个参数都是在生成模型输出的时候非常重要的。generation_max_length确定了生成文本的最大长度,generation_num_beams制了搜索的宽度,影响生成文本的质量和生成速度。当束宽(beam width)增加,搜索空间也就相应扩大,意味着生成的结果可能更优,但同时也会增加计算量和计算时间。

猜你喜欢

转载自blog.csdn.net/sinat_37574187/article/details/131458667