ChatGLM Efficient Tuning source code analysis src/utils/seq2seq.py (2)

class Seq2SeqTrainerForChatGLM(PeftTrainer):
    r"""
    Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
    """

    def save_predictions(
            self,
            predict_results: PredictionOutput,
            tokenizer: PreTrainedTokenizer
    ) -> None:
        r"""
        Saves model predictions to `output_dir`.

        A custom behavior that is not contained in Seq2SeqTrainer.
        """
        if not self.is_world_process_zero():
            return

        preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
        labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)

        preds = preds[:, labels.shape[1]:] # remove prompts
        preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds]
        labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels]

        output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
        logger.info(f"Saving prediction results to {output_prediction_file}")
        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for pred, label in zip(preds, labels):
                res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
            writer.write("\n".join(res))

Here's a Python snippet that defines a class called Seq2SeqTrainerForChatGLM, which inherits from PeftTrainer. This class handles the training of sequence-to-sequence models and the computation of generative metrics such as BLEU and ROUGE.

Next, we will explain this code line by line:

  1. class Seq2SeqTrainerForChatGLM(PeftTrainer): - This line defines a new class, Seq2SeqTrainerForChatGLM, that inherits from the PeftTrainer class.

  2. def save_predictions(self, predict_results: PredictionOutput, tokenizer: PreTrainedTokenizer) -> None: - This defines the save_predictions method, which accepts two parameters: predict_results (the prediction results, of type PredictionOutput) and tokenizer (the pretrained tokenizer, of type PreTrainedTokenizer).

  3. if not self.is_world_process_zero(): return - This line checks whether the current process is the main process (process rank 0). If not, the method returns immediately, so the predictions are only written once in distributed training.

  4. preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) - This line builds a new prediction array: values at positions not equal to IGNORE_INDEX are kept, while ignored positions are replaced with the tokenizer's pad token ID so they can be decoded safely.

  5. labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) - This line builds a new label array in the same way: label values at non-ignored positions are kept, and ignored positions are replaced with the pad token ID.

  6. preds = preds[:, labels.shape[1]:] - This line removes the prompt tokens, which sit at the beginning of each predicted sequence, by dropping the first labels.shape[1] columns (a small numpy sketch of steps 4-6 follows the list).

  7. preds = [tokenizer.decode(pred, skip_special_tokens=True).strip() for pred in preds] - This line decodes each prediction with the tokenizer, skipping special tokens, and strips leading and trailing whitespace.

  8. labels = [tokenizer.decode(label, skip_special_tokens=True).strip() for label in labels] - This line decodes the labels with the tokenizer in the same way, skipping special tokens and stripping surrounding whitespace.

  9. output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") - This line builds the output file path for the prediction results inside output_dir.

  10. logger.info(f"Saving prediction results to {output_prediction_file}") - This line logs a message indicating which file the predictions will be saved to.

  11. with open(output_prediction_file, "w", encoding="utf-8") as writer: - This line opens the prediction output file in write mode with UTF-8 encoding, ready to write the prediction results.

  12. res: List[str] = [] - This line initializes an empty list that will collect one JSON string per example.

  13. for pred, label in zip(preds, labels): - This line starts a loop over the paired predictions and labels.

  14. res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) - Inside the loop, this line serializes each prediction-label pair to a JSON string (ensure_ascii=False keeps non-ASCII characters such as Chinese readable) and appends it to the result list.

  15. writer.write("\n".join(res)) - This line joins the results with newlines and writes them to the file, producing one JSON object per line, i.e. the JSONL format (see the sketch after this list).
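
To make steps 4 to 6 concrete, here is a minimal numpy sketch of the masking and prompt-removal logic. The concrete values (IGNORE_INDEX = -100, a pad token ID of 3, and the toy token IDs) are assumptions for illustration only; in the project they come from the training constants and the ChatGLM tokenizer.

import numpy as np

IGNORE_INDEX = -100   # assumed value, following the common label-masking convention
pad_token_id = 3      # assumed pad token ID for illustration

# Generated sequences: 3 prompt tokens followed by the generated answer tokens.
predictions = np.array([
    [11, 12, 13, 101, 102, 103],
    [21, 22, 23, 201, 202, -100],
])
# Reference labels: only the answer tokens, padded with IGNORE_INDEX.
label_ids = np.array([
    [101, 102, 103],
    [201, 202, -100],
])

# Steps 4-5: replace ignored positions with the pad token so they can be decoded.
preds = np.where(predictions != IGNORE_INDEX, predictions, pad_token_id)
labels = np.where(label_ids != IGNORE_INDEX, label_ids, pad_token_id)

# Step 6: drop the first labels.shape[1] columns, i.e. the prompt portion.
# In this toy example the prompt length equals labels.shape[1], matching the
# slicing assumption made in the source.
preds = preds[:, labels.shape[1]:]
print(preds)   # [[101 102 103] [201 202   3]]
print(labels)  # [[101 102 103] [201 202   3]]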
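
Steps 11 to 15 then produce the JSONL file. Below is a small self-contained sketch of that write path, assuming preds and labels have already been decoded into strings by tokenizer.decode(..., skip_special_tokens=True); the example strings and the ./output directory are placeholders, not values from the project.

import json
import os

preds = ["Paris is the capital of France.", "你好，很高兴认识你。"]   # assumed decoded predictions
labels = ["The capital of France is Paris.", "你好！"]              # assumed decoded references
output_dir = "./output"                                             # placeholder for self.args.output_dir
os.makedirs(output_dir, exist_ok=True)

output_prediction_file = os.path.join(output_dir, "generated_predictions.jsonl")
res = []
for pred, label in zip(preds, labels):
    # ensure_ascii=False keeps Chinese characters readable instead of \uXXXX escapes.
    res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
with open(output_prediction_file, "w", encoding="utf-8") as writer:
    writer.write("\n".join(res))  # one JSON object per line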

Origin blog.csdn.net/sinat_37574187/article/details/131459455