ChatGLM Efficient Tuning source code analysis: train_sft.py

def main():

    # Prepare pretrained model and dataset
    model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft")
    dataset = prepare_data(model_args, data_args)
    model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
    dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
    data_collator = DataCollatorForChatGLM(
        tokenizer=tokenizer,
        model=model,
        ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate),
        use_v2=model_args.use_v2
    )

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = training_args.generation_max_length if \
                training_args.generation_max_length is not None else data_args.max_target_length
    training_args.generation_num_beams = data_args.eval_num_beams if \
                data_args.eval_num_beams is not None else training_args.generation_num_beams

    # Split the dataset
    if training_args.do_train:
        if data_args.dev_ratio > 1e-6:
            dataset = dataset.train_test_split(test_size=data_args.dev_ratio)
            trainer_kwargs = {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
        else:
            trainer_kwargs = {"train_dataset": dataset}
    else: # do_eval or do_predict
        trainer_kwargs = {"eval_dataset": dataset}

    # Initialize our Trainer
    trainer = Seq2SeqTrainerForChatGLM(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=[LogCallback()],
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **trainer_kwargs
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_new_tokens": data_args.max_target_length + 1,
        "temperature": 0.95,
        "logits_processor": get_logits_processor()
    }

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results, tokenizer)

This is a complex deep learning training pipeline, and I will explain the process line by line.

  1. def main(): defines the main entry-point function of the script.

  2. model_args, data_args, training_args, finetuning_args = prepare_args(stage="sft"): this line calls prepare_args to obtain four groups of arguments: model arguments (model_args), data arguments (data_args), training arguments (training_args), and fine-tuning arguments (finetuning_args). stage="sft" is the parameter passed to prepare_args, indicating the supervised fine-tuning stage.

  3. dataset = prepare_data(model_args, data_args): prepares the dataset according to the model and data arguments.

  4. model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft"): loads the pretrained model and its corresponding tokenizer.

  5. dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft"): preprocesses the loaded dataset, e.g. encoding the text with the tokenizer.

  6. data_collator = DataCollatorForChatGLM(...): creates a data collator that assembles and pads batches during training.

  7. The next two lines override the decoding parameters of Seq2SeqTrainer.

  8. The next if block decides, based on training_args.do_train, whether and how to split the dataset. When training, the data is split into training and validation sets according to the configured ratio (dev_ratio); when only evaluating or predicting, the whole dataset is used as the evaluation set (a minimal split sketch follows this list).

  9. trainer = Seq2SeqTrainerForChatGLM(...): creates a Seq2SeqTrainer with the arguments above.

  10. gen_kwargs = {...}: sets the keyword arguments for model.generate that control text generation.

  11. The following if block: if training is enabled (training_args.do_train), the model is trained, the training metrics are logged and saved, and the model itself is saved after training.

  12. The next if block handles evaluation: if evaluation is enabled (training_args.do_eval), the model is evaluated and the evaluation results are logged and saved.

  13. The last if block handles prediction: if prediction is enabled (training_args.do_predict), the model runs prediction and the results are logged and saved.
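
As a concrete illustration of step 8, here is a minimal, self-contained sketch of the dev split using the Hugging Face datasets library; the toy data and the dev_ratio value are made up for the example.

from datasets import Dataset

# Toy dataset standing in for the preprocessed SFT data (illustrative values only).
dataset = Dataset.from_dict({"input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]]})

dev_ratio = 0.25  # plays the role of data_args.dev_ratio
if dev_ratio > 1e-6:
    # train_test_split returns a DatasetDict with "train" and "test" splits.
    splits = dataset.train_test_split(test_size=dev_ratio)
    trainer_kwargs = {"train_dataset": splits["train"], "eval_dataset": splits["test"]}
else:
    trainer_kwargs = {"train_dataset": dataset}

print({k: len(v) for k, v in trainer_kwargs.items()})  # {'train_dataset': 3, 'eval_dataset': 1}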

Overall, this is a typical training, evaluation, and prediction workflow: data is loaded, preprocessed, and used to train or evaluate the model, and the various parameters involved can be configured through the prepare_args function.

ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate): this line sets one of the DataCollatorForChatGLM parameters.

Here, ignore_pad_token_for_loss controls whether padding tokens are ignored when the loss is computed.

In natural language processing, sequences of different lengths often need to be padded to the same length so they can be processed together in a batch; this is done by appending special "pad" tokens. When computing the model's loss, however, we usually do not want these padding tokens to affect the result, so they are ignored.

data_args.ignore_pad_token_for_loss is a boolean; if it is True, padding tokens are ignored when computing the loss. training_args.predict_with_generate is also a boolean and indicates whether the generate method is used during evaluation/prediction. If predict_with_generate is True, new sequences are generated during prediction, and in that case we do not want to ignore the padding tokens.

Therefore, the result of data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate determines whether padding tokens are ignored when computing the loss: they are ignored only when ignoring them is requested and generation is not used during prediction.
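
In practice, "ignoring pad tokens in the loss" is usually implemented by replacing padded label positions with -100, the index that PyTorch's cross-entropy loss skips by default. The following is only a minimal sketch of that idea with made-up tensors; the actual DataCollatorForChatGLM may implement it differently.

import torch

IGNORE_INDEX = -100  # label value skipped by torch.nn.functional.cross_entropy by default

pad_token_id = 0
labels = torch.tensor([[5, 6, 7, 0, 0]])  # a padded target sequence (illustrative ids)

ignore_pad_token_for_loss = True
if ignore_pad_token_for_loss:
    # Mask out padding positions so they do not contribute to the loss.
    labels = labels.masked_fill(labels == pad_token_id, IGNORE_INDEX)

print(labels)  # tensor([[   5,    6,    7, -100, -100]])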

# Keyword arguments for `model.generate`
    gen_kwargs = {
        "do_sample": True,
        "top_p": 0.7,
        "max_new_tokens": data_args.max_target_length + 1,
        "temperature": 0.95,
        "logits_processor": get_logits_processor()
    }

Why is max_new_tokens here set to max_target_length + 1?

In the code, the max_new_tokens parameter is set to data_args.max_target_length + 1. This is because in many natural language processing (NLP) tasks, including sequence generation, a special end token (such as <eos> or <pad>) is appended to each generated sequence to mark its end.

Therefore, max_new_tokens actually represents the maximum length of the generated sequence including this end token. That is why max_new_tokens is usually set to max_target_length + 1, where max_target_length is the maximum length of the sequence we expect to generate (excluding the end token) and the +1 reserves a slot for the end token.

The exact setting, however, depends on the model and task and on how the tokenizer handles the end token; in some cases you may want to set max_new_tokens to max_target_length instead of max_target_length + 1.
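
A tiny, made-up example of the idea: if an answer uses the full target length, the extra token leaves room for the end symbol.

max_target_length = 4                  # longest answer we expect, excluding the end token
answer_tokens = [101, 102, 103, 104]   # an answer that uses the full budget (illustrative ids)
eos_token_id = 2                       # hypothetical end-of-sequence id

max_new_tokens = max_target_length + 1  # one extra slot reserved for the end token
generated = answer_tokens + [eos_token_id]
assert len(generated) <= max_new_tokens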

What does def _mp_fn(index) mean? It is a hook for xla_spawn, the launcher for PyTorch/XLA, the extension that lets PyTorch run on Google TPUs:

def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()

This code defines a function _mp_fn that takes one parameter, index (unused here), and simply calls the main function.

This hook is used when the script is launched through xla_spawn. xla_spawn is a launcher script for PyTorch XLA (an extension of PyTorch that supports Google's TPUs) that lets the same code run on multiple TPU cores; it uses multiprocessing to start one Python process per TPU core.

When you run your code under xla_spawn, you need a function (here _mp_fn) that takes an index parameter and calls your main program (here main) inside it. The index parameter is the index of the current process, in other words the index of the current TPU core; your main program then runs on every TPU core.

In this particular example, _mp_fn does not use index, so main behaves the same on all TPU cores. In more complex setups you might use index to vary the behavior of each process, for example to have different processes handle different shards of the data.
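
For reference, a hedged sketch of how such an entry point is typically launched with PyTorch/XLA's multiprocessing helper; the process count is only an example and the exact spawning API can vary with the torch_xla version.

# Requires a TPU environment with torch_xla installed.
import torch_xla.distributed.xla_multiprocessing as xmp


def main():
    ...  # stands in for the training entry point defined earlier in train_sft.py


def _mp_fn(index):
    # `index` is the ordinal of the TPU core assigned to this process.
    main()


if __name__ == "__main__":
    # Start one Python process per TPU core; each one calls _mp_fn(index).
    xmp.spawn(_mp_fn, args=(), nprocs=8)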

This part of the code uses a trainer called Seq2SeqTrainerForChatGLM to train, evaluate, and run predictions with a sequence-to-sequence model. Below is a line-by-line explanation:

  1. trainer = Seq2SeqTrainerForChatGLM(...): a Seq2SeqTrainerForChatGLM object is initialized here. This trainer is designed for sequence-to-sequence tasks (such as dialogue systems or translation). Its arguments include the fine-tuning arguments (finetuning_args), the model (model), the training arguments (training_args), the tokenizer (tokenizer), the data collator (data_collator), the callbacks (callbacks), the metric computation (compute_metrics), and the additional trainer arguments unpacked from trainer_kwargs.

  2. gen_kwargs = {...}: defines a dictionary of keyword arguments that control how the model generates text: do_sample decides whether to sample, top_p controls how concentrated nucleus sampling is, max_new_tokens caps the number of newly generated tokens, temperature adjusts the randomness of sampling, and logits_processor post-processes the logits (a small sampling sketch appears after this list).

  3. if training_args.do_train:: this condition determines whether the training phase runs; if so, the code below is executed.

  4. train_result = trainer.train(): trains the model with the trainer defined above and returns the training result.

  5. trainer.log_metrics(...) and trainer.save_metrics(...): these two lines log and save the training metrics (such as the loss).

  6. trainer.save_state() and trainer.save_model(): these two lines save the trainer state and the model.

  7. if trainer.is_world_process_zero() and model_args.plot_loss:: if this is the main process in a multi-process setup and loss plotting is enabled, plot the loss curve.

  8. if training_args.do_eval:: this condition determines whether the evaluation phase runs; if so, the code below is executed.

  9. metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs): evaluates the model with the previously defined trainer and returns the evaluation metrics.

  10. if training_args.predict_with_generate:: if generation is used for prediction, eval_loss will be incorrect and is removed from the metrics.

  11. trainer.log_metrics("eval", metrics) and trainer.save_metrics("eval", metrics): these two lines log and save the evaluation metrics.

  12. if training_args.do_predict:: this condition determines whether the prediction phase runs; if so, the code below is executed.

  13. predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs): runs prediction with the previously defined trainer and returns the prediction results.

  14. if training_args.predict_with_generate:: likewise, predict_loss will be incorrect when generation is used and is removed from the metrics.

  15. trainer.log_metrics("predict", predict_results.metrics) and trainer.save_metrics("predict", predict_results.metrics): these two lines log and save the prediction metrics.

  16. trainer.save_predictions(predict_results, tokenizer): Save the predicted results.

The above is a line-by-line explanation of this code.
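
To make the sampling-related keys in gen_kwargs (item 2 above) concrete, here is an illustrative single-step sketch of temperature scaling and nucleus (top-p) sampling; it is not the code that transformers actually runs, only the idea behind do_sample, top_p, and temperature.

import torch

def sample_next_token(logits, temperature=0.95, top_p=0.7):
    """Illustrative nucleus (top-p) sampling for one decoding step."""
    # Temperature rescales the logits: values below 1 sharpen the distribution.
    probs = torch.softmax(logits / temperature, dim=-1)
    # Sort tokens by probability and keep the smallest set whose cumulative
    # probability stays within top_p (the "nucleus"); the top token is always kept.
    sorted_probs, sorted_ids = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = (cumulative - sorted_probs) < top_p
    kept_probs = sorted_probs * keep
    kept_probs = kept_probs / kept_probs.sum()
    # Sample one token id from the renormalized nucleus.
    choice = torch.multinomial(kept_probs, num_samples=1)
    return sorted_ids[choice]

# Example with random logits over a toy vocabulary of 1000 tokens.
next_token_id = sample_next_token(torch.randn(1000))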

  • trainer.train(): this function implements the training process of the model. Specifically, it runs a training loop that forward-propagates each batch of data, computes the loss, and back-propagates to update the model parameters (a simplified sketch of one such step follows this list).

  • trainer.evaluate(): this function implements the model evaluation process. It measures the model's performance on the given validation set and returns evaluation metrics (here computed by ComputeMetrics when predict_with_generate is enabled).

  • trainer.predict(): This function implements the prediction process of the model. It predicts the output of the model on a given test set, which may then be further processed depending on the specific task.
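
Conceptually, one step of the loop behind trainer.train() looks roughly like the generic PyTorch sketch below; the real Trainer adds gradient accumulation, mixed precision, learning-rate scheduling, and distributed handling on top of this.

import torch


def training_step(model, batch, optimizer):
    """One simplified training step (illustrative, not the Trainer's actual code)."""
    model.train()
    # Forward pass: Hugging Face-style models return the loss when labels are in the batch.
    outputs = model(**batch)
    loss = outputs.loss
    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()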

This code also shows some common patterns for working with trainer objects, such as saving and loading models, recording and saving metrics, and handling parallel training and prediction. Overall, this is a very general framework for model training, evaluation, and prediction that can be used for many different sequence-to-sequence tasks.

training_args.generation_max_length

  1. training_args.generation_max_length: this line sets the maximum length of the generated text. If training_args.generation_max_length is not set (i.e. it is None), data_args.max_target_length is used as the default; otherwise training_args.generation_max_length is used directly. In essence, this provides generation_max_length with a default value.

  2. training_args.generation_num_beams: this line sets the beam width for beam search. If data_args.eval_num_beams is not set (i.e. it is None), training_args.generation_num_beams is kept as the default; otherwise data_args.eval_num_beams is used. In essence, this provides generation_num_beams with a default value.

Both parameters matter when generating model output: generation_max_length determines the maximum length of the generated text, while generation_num_beams controls the search width and therefore affects both the quality and the speed of generation. A larger beam width expands the search space, which may yield better results but also increases computation and generation time.
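
A self-contained restatement of the two overrides, with placeholder argument objects whose values are only examples:

# Stand-ins for the real argument objects (the values are only examples).
class TrainingArgs:
    generation_max_length = None  # not set explicitly
    generation_num_beams = 1

class DataArgs:
    max_target_length = 512
    eval_num_beams = 4

training_args, data_args = TrainingArgs(), DataArgs()

# Same override logic as in main(): prefer the explicit value, otherwise fall back.
training_args.generation_max_length = (
    training_args.generation_max_length
    if training_args.generation_max_length is not None
    else data_args.max_target_length
)
training_args.generation_num_beams = (
    data_args.eval_num_beams
    if data_args.eval_num_beams is not None
    else training_args.generation_num_beams
)

print(training_args.generation_max_length, training_args.generation_num_beams)  # 512 4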
