ChatGLM Source Code Analysis: main.py

import logging
import os
import sys
import json

import numpy as np
from datasets import load_dataset
import jieba 
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)
from trainer_seq2seq import Seq2SeqTrainer

from arguments import ModelArguments, DataTrainingArguments

logger = logging.getLogger(__name__)

def main():

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Load dataset
    data_files = {}
    if data_args.train_file is not None:
        data_files["train"] = data_args.train_file
        extension = data_args.train_file.split(".")[-1]
    if data_args.validation_file is not None:
        data_files["validation"] = data_args.validation_file
        extension = data_args.validation_file.split(".")[-1]
    if data_args.test_file is not None:
        data_files["test"] = data_args.test_file
        extension = data_args.test_file.split(".")[-1]

    raw_datasets = load_dataset(
        extension,
        data_files=data_files,
        cache_dir=model_args.cache_dir,
        use_auth_token=True if model_args.use_auth_token else None,
    )

    # Load pretrained model and tokenizer
    config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
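    # P-tuning v2 settings: the length of the trainable prefix and whether the
    # prefix encoder uses a projection MLP.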
    config.pre_seq_len = model_args.pre_seq_len
    config.prefix_projection = model_args.prefix_projection

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    if model_args.ptuning_checkpoint is not None:
        # Evaluation
        # Loading extra state dict of prefix encoder
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
        prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
        new_prefix_state_dict = {}
        for k, v in prefix_state_dict.items():
            if k.startswith("transformer.prefix_encoder."):
                new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
        model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)

    if model_args.quantization_bit is not None:
        print(f"Quantized to {model_args.quantization_bit} bit")
        model = model.quantize(model_args.quantization_bit)
    if model_args.pre_seq_len is not None:
        # P-tuning v2
        model = model.half()
        model.transformer.prefix_encoder.float()
    else:
        # Finetune
        model = model.float()

    prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

    # Preprocessing the datasets.
    # We need to tokenize inputs and targets.
    if training_args.do_train:
        column_names = raw_datasets["train"].column_names
    elif training_args.do_eval:
        column_names = raw_datasets["validation"].column_names
    elif training_args.do_predict:
        column_names = raw_datasets["test"].column_names
    else:
        logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    # Get the column names for input/target.
    prompt_column = data_args.prompt_column
    response_column = data_args.response_column
    history_column = data_args.history_column
    
    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length

    def preprocess_function_eval(examples):
        inputs, targets = [], []
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query = examples[prompt_column][i]
                if history_column is None or len(examples[history_column][i]) == 0:
                    prompt = query
                else:
                    prompt = ""
                    history = examples[history_column][i]
                    for turn_idx, (old_query, response) in enumerate(history):
                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
                inputs.append(prompt)
                targets.append(examples[response_column][i])

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True)
        labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)

        if data_args.ignore_pad_token_for_loss:
            labels["input_ids"] = [
                [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
            ]
        model_inputs["labels"] = labels["input_ids"]

        return model_inputs

    def preprocess_function_train(examples):
        max_seq_length = data_args.max_source_length + data_args.max_target_length

        model_inputs = {
            "input_ids": [],
            "labels": [],
        }
        for i in range(len(examples[prompt_column])):
            if examples[prompt_column][i] and examples[response_column][i]:
                query, answer = examples[prompt_column][i], examples[response_column][i]

                if history_column is None:
                    prompt = query
                else:
                    prompt = ""
                    history = examples[history_column][i]
                    for turn_idx, (old_query, response) in enumerate(history):
                        prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
                    prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)

                prompt = prefix + prompt
                a_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
                b_ids = tokenizer.encode(text=answer, add_special_tokens=False)

                if len(a_ids) > data_args.max_source_length - 1:
                    a_ids = a_ids[: data_args.max_source_length - 1]

                if len(b_ids) > data_args.max_target_length - 2:
                    b_ids = b_ids[: data_args.max_target_length - 2]

                input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids)

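                # ChatGLM appends [gMASK] and BOS after the prompt tokens, so the
                # position of BOS marks the end of the context; everything before
                # it is masked out of the loss below.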
                context_length = input_ids.index(tokenizer.bos_token_id)
                mask_position = context_length - 1
                labels = [-100] * context_length + input_ids[mask_position+1:]
                
                pad_len = max_seq_length - len(input_ids)
                input_ids = input_ids + [tokenizer.pad_token_id] * pad_len
                labels = labels + [tokenizer.pad_token_id] * pad_len
                if data_args.ignore_pad_token_for_loss:
                    labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels]

                model_inputs["input_ids"].append(input_ids)
                model_inputs["labels"].append(labels)

        return model_inputs

This file is a Python script for training and evaluating a sequence-to-sequence model. It breaks down into three parts: importing libraries and modules, defining the main function, and defining the preprocessing functions.

The first part imports the required libraries and modules. The major ones include:

  • logging: records log information during program execution.
  • os and sys: provide functions and variables for interacting with the operating system.
  • json: reads and writes JSON-format data.
  • numpy: works with numeric arrays and matrices.
  • jieba: a Chinese word segmentation tool (see the short example after this list).
  • datasets: a library for loading and processing datasets.
  • transformers: a widely used NLP library containing many pre-trained models and utility functions.
  • torch: the PyTorch framework, used to create and train neural networks.
  • Rouge and sentence_bleu: evaluation metrics.
  • Seq2SeqTrainer: a trainer for sequence-to-sequence models.
  • ModelArguments and DataTrainingArguments: custom argument dataclasses.
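Since jieba drives the Chinese-aware metrics later in the script, a one-line illustration (the segmentation shown is only an example):

import jieba

print(list(jieba.cut("今天天气很好")))   # e.g. ['今天', '天气', '很', '好']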

The second part defines the main function main(). It first parses arguments from the command line, or from a JSON file if one is passed as the only argument. It then sets up logging, sets the random seed, loads the dataset, and loads the pre-trained model and tokenizer. After that, it applies some additional settings to the model, such as quantization and half-precision floating point. Finally, depending on whether the run is for training, evaluation, or prediction, it selects the corresponding dataset and begins preprocessing.
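As a concrete illustration, here is a minimal sketch of the JSON branch. The field names below are assumptions; they must match the dataclass fields defined in arguments.py and Seq2SeqTrainingArguments:

# Hedged sketch: exercising the parse_json_file branch of main().
# Field names are assumptions; they must match the dataclasses in
# arguments.py and Seq2SeqTrainingArguments.
import json

config = {
    "model_name_or_path": "THUDM/chatglm-6b",
    "train_file": "data/train.json",
    "prompt_column": "content",
    "response_column": "summary",
    "max_source_length": 64,
    "max_target_length": 64,
    "output_dir": "./output",
    "do_train": True,
}
with open("train_config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, ensure_ascii=False, indent=2)

# Running `python main.py train_config.json` then takes the
# parser.parse_json_file(...) branch instead of reading CLI flags.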

The third part defines two data preprocessing functions, preprocess_function_eval() and preprocess_function_train(). Both convert raw text data into an input format the model accepts: they turn the dialogue records in the raw data into question-answer pairs, tokenize them, and generate the model's inputs and labels.

Overall, this code defines a script that can get configuration information from command line arguments or json files, then load datasets and pretrained models, preprocess the data, and train or evaluate the model.

Next is the code that preprocesses the dataset, where a different preprocessing function is selected depending on whether the model is in the training phase or the evaluation phase. Both functions receive samples from the raw dataset as input and output a format the model can accept.

preprocess_function_eval() is called in the evaluation phase. It builds the multi-turn prompt, tokenizes the text into id sequences the model can accept, and processes the target labels. It also prepends the source prefix, if one is set.
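For instance, with a one-turn history, the prompt-building loop produces the following string (a small standalone illustration of the code above):

# Standalone illustration of the multi-turn prompt format built in
# preprocess_function_eval (and preprocess_function_train).
history = [("你好", "你好!有什么可以帮你?")]
query = "今天天气怎么样?"
prompt = ""
for turn_idx, (old_query, response) in enumerate(history):
    prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response)
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
print(prompt)
# [Round 0]
# 问:你好
# 答:你好!有什么可以帮你?
# [Round 1]
# 问:今天天气怎么样?
# 答: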

preprocess_function_train() is called in the training phase. It processes the input text and target text to meet the model's input requirements: tokenization, adding special tokens (such as BOS and EOS), and ensuring that neither the inputs nor the labels exceed the configured maximum lengths.
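To make the label construction concrete, here is a sketch with made-up token ids. The special-token layout (prompt + [gMASK, BOS] + answer + [EOS]) follows ChatGLM's build_inputs_with_special_tokens; the ids themselves are placeholders:

# Sketch of the label construction in preprocess_function_train, with
# illustrative (not real) token ids.
a_ids = [11, 12, 13]                      # tokenized prompt
b_ids = [21, 22]                          # tokenized answer
GMASK, BOS, EOS = 900, 901, 902           # placeholder special-token ids

input_ids = a_ids + [GMASK, BOS] + b_ids + [EOS]
context_length = input_ids.index(BOS)     # 4
labels = [-100] * context_length + input_ids[context_length:]
# labels == [-100, -100, -100, -100, 901, 21, 22, 902]
# every position before BOS is ignored by the loss; BOS through EOS are supervised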

After that, depending on whether the run is in the training, evaluation, or prediction phase, the corresponding preprocessing function is applied to the dataset to obtain the preprocessed dataset.

The next code segment sets the training and evaluation configuration of the model, including batch size, learning rate, optimizer, and so on, and creates a Seq2SeqTrainer object. This Trainer object handles the training and evaluation of the model.

Then, if training is requested, the preprocessed training and evaluation datasets are used to train the model, and the model state is saved after training.

The following code generates predictions and computes evaluation metrics during the evaluation and prediction phases. If evaluation is requested, the preprocessed evaluation set is used to evaluate the model and ROUGE and BLEU scores are computed. If prediction is requested, the preprocessed test set is run through the model.
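As a minimal standalone sketch of the per-pair metric computation used in compute_metrics (shown later in the code):

# Minimal sketch of the metric computation for one prediction/reference pair.
import jieba
from rouge_chinese import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

pred, label = "今天天气很好", "今天天气不错"
hypothesis = " ".join(jieba.cut(pred))
reference = " ".join(jieba.cut(label))
rouge_scores = Rouge().get_scores(hypothesis, reference)[0]   # rouge-1/2/l, each with f/p/r
bleu4 = sentence_bleu([list(label)], list(pred),
                      smoothing_function=SmoothingFunction().method3)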

Finally, if prediction is requested, the trained model is run on the preprocessed test set and the prediction results are written to generated_predictions.txt in the output directory.
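Each line of that file is one JSON object with labels and predict keys; a quick sketch of reading it back (the path assumes output_dir was ./output):

# Sketch: reading back the predictions file written at the end of main().
import json

with open("output/generated_predictions.txt", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)   # {"labels": "<reference>", "predict": "<model output>"}
        print(record["predict"])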

This covers the overall structure: data preprocessing, model training, model evaluation, and prediction. The remainder of main() is shown and analyzed below.

    
    def print_dataset_example(example):
        print("input_ids", example["input_ids"])
        print("inputs", tokenizer.decode(example["input_ids"]))
        print("label_ids", example["labels"])
        print("labels", tokenizer.decode(example["labels"]))

    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))
        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                preprocess_function_train,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on train dataset",
            )
        print_dataset_example(train_dataset[0])

    if training_args.do_eval:
        max_target_length = data_args.val_max_target_length
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))
        with training_args.main_process_first(desc="validation dataset map pre-processing"):
            eval_dataset = eval_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on validation dataset",
            )
        print_dataset_example(eval_dataset[0])

    if training_args.do_predict:
        max_target_length = data_args.val_max_target_length
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
            predict_dataset = predict_dataset.select(range(max_predict_samples))
        with training_args.main_process_first(desc="prediction dataset map pre-processing"):
            predict_dataset = predict_dataset.map(
                preprocess_function_eval,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on prediction dataset",
            )
        print_dataset_example(predict_dataset[0])

    # Data collator
    label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=label_pad_token_id,
        pad_to_multiple_of=None,
        padding=False
    )

    # Metric
    def compute_metrics(eval_preds):
        preds, labels = eval_preds
        if isinstance(preds, tuple):
            preds = preds[0]
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        if data_args.ignore_pad_token_for_loss:
            # Replace -100 in the labels as we can't decode them.
            labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        score_dict = {
            "rouge-1": [],
            "rouge-2": [],
            "rouge-l": [],
            "bleu-4": []
        }
        for pred, label in zip(decoded_preds, decoded_labels):
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            rouge = Rouge()
            scores = rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
            result = scores[0]
            
            for k, v in result.items():
                score_dict[k].append(round(v["f"] * 100, 4))
            bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
            score_dict["bleu-4"].append(round(bleu_score * 100, 4))

        for k, v in score_dict.items():
            score_dict[k] = float(np.mean(v))
        return score_dict

    # Override the decoding parameters of Seq2SeqTrainer
    training_args.generation_max_length = (
        training_args.generation_max_length
        if training_args.generation_max_length is not None
        else data_args.val_max_target_length
    )
    training_args.generation_num_beams = (
        data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
    )
    # Initialize our Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
        save_prefixencoder=model_args.pre_seq_len is not None
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        # elif last_checkpoint is not None:
        #     checkpoint = last_checkpoint
        model.gradient_checkpointing_enable()
        model.enable_input_require_grads()
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        # trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    results = {}
    max_seq_length = data_args.max_source_length + data_args.max_target_length + 1
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95)
        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    if training_args.do_predict:
        logger.info("*** Predict ***")
        predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95)
        metrics = predict_results.metrics
        max_predict_samples = (
            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
        )
        metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))

        trainer.log_metrics("predict", metrics)
        trainer.save_metrics("predict", metrics)

        if trainer.is_world_process_zero():
            if training_args.predict_with_generate:
                predictions = tokenizer.batch_decode(
                    predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                predictions = [pred.strip() for pred in predictions]
                labels = tokenizer.batch_decode(
                    predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
                )
                labels = [label.strip() for label in labels]
                output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
                with open(output_prediction_file, "w", encoding="utf-8") as writer:
                    for p, l in zip(predictions, labels):
                        res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False)
                        writer.write(f"{res}\n")
    return results


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()

This second half of the code handles the training and evaluation of the sequence-to-sequence (seq2seq) model. A block-by-block explanation:

print_dataset_example() prints a dataset sample: the raw input ids, the decoded input string, the label ids, and the decoded label string.

If training is enabled, the script first checks that a training dataset was provided and raises an exception if not. The training set is then preprocessed, which tokenizes the inputs (converting text into tokens the model can understand). If a maximum number of training samples is set, that many samples are selected from the dataset. Finally, the first sample of the preprocessed training set is printed.

If evaluation is enabled, the script checks that a validation dataset was provided, raising an exception if not. The validation set is preprocessed in the same way as the training set, and its first preprocessed sample is printed.

If prediction is enabled, the script checks that a test dataset was provided, raising an exception if not. The test set is preprocessed like the other splits, and its first preprocessed sample is printed.

Next, the data collator is defined, a tool for batching data. It pads the examples as needed so a batch can be organized into uniform-length tensors and fed to the model; a short sketch of its behavior follows.
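A hedged sketch of the collator's effect. In this script the inputs are already padded during preprocessing (hence padding=False), but variable-length labels in a batch are still padded by the collator, using label_pad_token_id = -100 so the loss ignores those positions. The t5-small tokenizer here is only an illustration:

# Hedged sketch of DataCollatorForSeq2Seq label padding; any seq2seq tokenizer
# works for the illustration, t5-small is just an example.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tok = AutoTokenizer.from_pretrained("t5-small")
collator = DataCollatorForSeq2Seq(tok, label_pad_token_id=-100)
features = [
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1], "labels": [8, 9]},
    {"input_ids": [5, 6, 0], "attention_mask": [1, 1, 0], "labels": [8]},
]
batch = collator(features)
# batch["labels"] -> tensor([[8, 9], [8, -100]])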

compute_metrics() computes evaluation metrics such as ROUGE and BLEU from the prediction results and the ground-truth labels.

Some generation parameters are then set on the training arguments, such as the maximum length of generated sequences and the number of beams (for beam search decoding).

The trainer is initialized next. It is the tool that trains and evaluates the model, and it accepts the model, training arguments, datasets, tokenizer, data collator, and metric function as input.

If training is performed, the trainer's train method is called, and the model state and training metrics are saved.

If evaluation is performed, the trainer's evaluate method is called and the evaluation metrics are saved.

If prediction is performed, the trainer's predict method is called, and both the prediction metrics and the generated predictions are saved.

_mp_fn() is an auxiliary function that distributes execution of main() across TPU cores when using TPUs for distributed training.

If the script is run directly, main() is called.

The main logic of this script is in the main() function, whose beginning was shown in the first snippet. It contains roughly the following steps:

  1. Parse arguments: define and handle arguments for training, evaluation, and prediction from the command line or a JSON file. These include the number of epochs, the learning rate, the choice of optimizer, and so on.

  2. Load datasets: training, validation, and test data are loaded from the specified paths.

  3. Load the model: a pre-trained or custom model is loaded.

  4. Call the functions defined earlier: for training, evaluation, prediction, and so on.

  5. Save the model: after training, the model state is saved.

_mp_fn() is used during TPU training; it helps distribute the work across the TPU cores.
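A typical launch pattern, assuming the xla_spawn.py helper from the transformers examples is available:

# Assumed launch pattern for TPUs, using the xla_spawn.py helper shipped
# with the transformers examples:
#   python xla_spawn.py --num_cores 8 main.py train_config.json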

In general, this is a typical deep learning training and evaluation script; the specific functions and flow may vary depending on how main() is implemented.

Finally, here is a finer-grained, statement-by-statement reading of the same code, covering Seq2Seq model training, validation, and prediction:

  1. def print_dataset_example(example): defines a function that prints out the individual components of a dataset sample. This is mainly used to inspect the structure and content of the dataset.

  2. print("input_ids", example["input_ids"]) prints the sample's input ids.

  3. print("inputs", tokenizer.decode(example["input_ids"])) decodes the input ids with the tokenizer to show the original input text.

  4. print("label_ids", example["labels"]) prints the sample's label ids.

  5. print("labels", tokenizer.decode(example["labels"])) decodes the label ids with the tokenizer to show the original label text.

  6. if training_args.do_train: determines whether to run model training.

  7. if "train" not in raw_datasets: raises an error if no training dataset was provided.

  8. train_dataset = raw_datasets["train"] fetches the training dataset.

  9. if data_args.max_train_samples is not None: checks whether a maximum number of training samples was set.

  10. max_train_samples = min(len(train_dataset), data_args.max_train_samples) caps the number of training samples.

  11. train_dataset = train_dataset.select(range(max_train_samples)) truncates the training set to that many samples.

  12. with training_args.main_process_first(desc="train dataset map pre-processing"): lets the main process run the preprocessing first, so other processes can reuse the cached result.

  13. train_dataset = train_dataset.map(...) preprocesses the training dataset.

  14. print_dataset_example(train_dataset[0]) prints one sample of the preprocessed training set.

  15. if training_args.do_eval: determines whether to run model evaluation.

  16. Similar to the training split, this part preprocesses the validation dataset and then prints one sample.

  17. if training_args.do_predict: determines whether to run model prediction.

  18. Similar to the training and validation splits, this part preprocesses the test dataset and then prints one sample.

  19. label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id sets the token id used to pad labels.

  20. data_collator = DataCollatorForSeq2Seq(...) initializes a data collator for batching.

  21. def compute_metrics(eval_preds): defines the function that computes evaluation metrics.

  22. decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) batch-decodes the predictions.

  23. decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) batch-decodes the ground-truth labels.

  24. score_dict = {...} initializes the score dictionary.

  25. In the loop, each prediction/reference pair is segmented with jieba, its ROUGE and BLEU-4 scores are computed, and the results are appended to the score dictionary.

  26. for k, v in score_dict.items(): averages each score over the dataset.

  27. return score_dict returns all computed scores.

  28. training_args.generation_max_length = ... sets the maximum length of generated sequences.

  29. training_args.generation_num_beams = ... sets the beam width used when generating sequences.

  30. trainer = Seq2SeqTrainer(...) initializes the Seq2Seq trainer.

  31. if training_args.do_train: if training is enabled, runs the trainer's training loop and saves the training metrics and state.

  32. results = {} initializes the results dictionary.

  33. if training_args.do_eval: if evaluation is enabled, runs the trainer's evaluation loop and saves the evaluation metrics.

  34. if training_args.do_predict: if prediction is enabled, runs the trainer's prediction loop and saves the prediction metrics and generated outputs.

  35. return results returns all results.

  36. def _mp_fn(index): defines the entry point used for TPU multi-processing.

  37. main() calls the main function.

  38. if __name__ == "__main__": checks whether the script is being run directly, and if so calls main().

The above is a statement-by-statement explanation of the code.
